Setup

workflow_name <- "netnav_06_model_comparison"

library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr     1.1.4     ✔ readr     2.1.5
## ✔ forcats   1.0.0     ✔ stringr   1.5.1
## ✔ ggplot2   3.4.4     ✔ tibble    3.2.1
## ✔ lubridate 1.9.3     ✔ tidyr     1.3.1
## ✔ purrr     1.0.2     
## ── Conflicts ────────────────────────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(here)
## here() starts at /Users/jaeyoungson/Documents/GitHub/network-navigation-replay
library(patchwork)

library(glmmTMB)
## Warning in checkDepPackageVersion(dep_pkg = "TMB"): Package version inconsistency detected.
## glmmTMB was built with TMB version 1.9.6
## Current TMB version is 1.9.10
## Please re-install glmmTMB from source or restore original 'TMB' package (see '?reinstalling' for more information)
library(broom.mixed)

source(here("code", "utils", "modeling_utils.R"))
source(here("code", "utils", "representation_utils.R"))
source(here("code", "utils", "bayesian_model_selection.R"))

source(here("code", "utils", "ggplot_themes.R"))
source(here("code", "utils", "kable_utils.R"))
## 
## Attaching package: 'kableExtra'
## 
## The following object is masked from 'package:dplyr':
## 
##     group_rows
source(here("code", "utils", "unicode_greek.R"))

knitting <- knitr::is_html_output()

# Ensure that a directory exists. When `this_path` is absent it is created,
# along with any missing parent directories; an existing directory is left
# untouched.
create_path <- function(this_path) {
  path_is_missing <- !dir.exists(this_path)
  if (path_is_missing) {
    dir.create(this_path, recursive = TRUE)
  }
}

# Append model predictions to the data frame they were generated for.
#
# `make_predictions_for` is passed to predict() as `newdata`, and whatever
# predict() returns is column-bound onto it. The call requests response-scale
# predictions with standard errors (`se.fit = TRUE`), sets `re.form = NA`, and
# allows factor levels that were not present at fitting time.
# NOTE(review): per the glmmTMB docs, `re.form = NA` yields population-level
# predictions (random effects excluded) -- confirm for the installed version.
predict_glmmTMB <- function(make_predictions_for, model_object) {
  model_predictions <- predict(
    object = model_object,
    newdata = make_predictions_for,
    re.form = NA,
    allow.new.levels = TRUE,
    se.fit = TRUE,
    type = "response"
  )

  bind_cols(make_predictions_for, model_predictions)
}

# Output folders are only needed when the document is being knit to HTML,
# because that is the only situation in which figures get written to disk.
if (knitting) {
  create_path(here("outputs", workflow_name))
  create_path(here("figures"))
}
# Read one study's message-passing (navigation) data and apply the steps that
# are common to every study.
#
# @param study_number Study index (1, 2, or 3). Selects the file
#   data/clean_data/study<N>_message_passing.csv and labels rows "Study <N>".
# @return A tibble restricted to trials where `shortest_path_given_opts`
#   equals `shortest_path_given_start_end`, with `study` and a factor
#   `shortest_path` added. All raw columns are kept so that the caller can
#   still recode `measurement_id` (which, for Studies 2 and 3, needs `network`).
read_nav_study <- function(study_number) {
  here(
    "data", "clean_data",
    str_c("study", study_number, "_message_passing.csv")
  ) %>%
    read_csv(show_col_types = FALSE) %>%
    filter(
      # two_correct_options == FALSE,
      shortest_path_given_opts == shortest_path_given_start_end
    ) %>%
    mutate(
      study = str_c("Study ", study_number),
      shortest_path = factor(shortest_path_given_opts)
    )
}

# Keep (and rename) the columns shared by all three studies, in a fixed order,
# so that the per-study tibbles can later be stacked with bind_rows().
select_nav_columns <- function(nav_data) {
  nav_data %>%
    select(
      study, sub_id, measurement_id, shortest_path,
      startpoint_id, endpoint_id,
      opt1_id, opt2_id,
      correct_choice, sub_choice,
      correct, rt,
      two_correct_options,
      opt1_distance = dist_opt1,
      opt2_distance = dist_opt2
    )
}

# Study 1: measurement IDs are simply prefixed with "D"
nav_study1 <- read_nav_study(1) %>%
  mutate(measurement_id = str_c("D", measurement_id)) %>%
  select_nav_columns()

# Study 2: trials on the reevaluated network get their own label ("D2b");
# any other `network` value falls through case_when() and becomes NA
nav_study2 <- read_nav_study(2) %>%
  mutate(
    measurement_id = case_when(
      network == "learned" ~ str_c("D", measurement_id),
      network == "reevaluated" ~ "D2b"
    )
  ) %>%
  select_nav_columns()

# Study 3: three measurements of the learned network (D1, D1b, D2), plus the
# reevaluated network ("D2b"), which takes precedence over the numeric ID
nav_study3 <- read_nav_study(3) %>%
  mutate(
    measurement_id = case_when(
      network == "reevaluated" ~ "D2b",
      measurement_id == 1 ~ "D1",
      measurement_id == 2 ~ "D1b",
      measurement_id == 3 ~ "D2"
    )
  ) %>%
  select_nav_columns()
# Load the BFS simulations for one search direction and collapse them to one
# row per unique trial.
#
# @param direction Either "backward" or "forward"; selects the file
#   data/bfs_sims/bfs_sims_learned_<direction>.csv.
# @return A tibble with one row per combination of shortest_path,
#   startpoint_id, endpoint_id, opt1_id, opt2_id, correct_choice and
#   two_correct_options, containing:
#   - p_bfs_correct: proportion of simulations where BFS chose correctly
#   - p_bfs_chooses_opt1: proportion of simulations where BFS chose option 1
#   - bfs_visits: mean of `bfs_n_visits_total` across simulations
summarise_bfs_sims <- function(direction = c("backward", "forward")) {
  direction <- match.arg(direction)

  here(
    "data", "bfs_sims",
    str_c("bfs_sims_learned_", direction, ".csv")
  ) %>%
    read_csv(show_col_types = FALSE) %>%
    filter(
      # two_correct_options == FALSE,
      shortest_path_given_opts == shortest_path_given_start_end
    ) %>%
    mutate(shortest_path = factor(shortest_path_given_opts)) %>%
    select(-starts_with("shortest_path_given")) %>%
    group_by(
      shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id,
      correct_choice, two_correct_options
    ) %>%
    summarise(
      p_bfs_correct = mean(bfs_choice == correct_choice),
      p_bfs_chooses_opt1 = mean(bfs_choice == opt1_id),
      bfs_visits = mean(bfs_n_visits_total),
      .groups = "drop"
    )
}

bfs_backward_sims <- summarise_bfs_sims("backward")

bfs_forward_sims <- summarise_bfs_sims("forward")

# Template of navigation trials (one row per trial), built from the Study 1
# data. Only trial-level columns are kept (no subject responses), sorted by
# shortest path and node IDs.
nav_trials <- here("data", "clean_data", "study1_message_passing.csv") %>%
  read_csv(show_col_types = FALSE) %>%
  filter(
    # two_correct_options == FALSE,
    shortest_path_given_opts == shortest_path_given_start_end
  ) %>%
  mutate(shortest_path = factor(shortest_path_given_opts)) %>%
  # Use a single subject's rows as the trial list
  # NOTE(review): presumably every subject saw the same set of trials, so
  # subject 1 is representative -- TODO confirm
  filter(sub_id == 1) %>%
  select(
    shortest_path,
    startpoint_id, endpoint_id,
    opt1_id, opt2_id,
    correct_choice,
    opt1_distance = dist_opt1,
    opt2_distance = dist_opt2,
    two_correct_options
  ) %>%
  arrange(shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id) %>%
  # Replace undefined distances (corresponding to impossible options)
  # so that the softmax gets non-NA inputs; we assume that impossible
  # options are just as bad as the longest distance found in this set
  # of trials, i.e., a distance of 8
  # NOTE(review): 8 is hard-coded; it must be revisited if the trial set or
  # network changes
  mutate(across(c(opt1_distance, opt2_distance), ~replace_na(.x, 8)))

# Adjacency list of the learned network (columns used below: from, to, edge)
adjlist <- here("data", "clean_data", "adjlist_learned.csv") %>%
  read_csv(show_col_types = FALSE)

# Row-normalized transition matrix: each `from` node's edge values are divided
# by their sum, then the long table is reshaped to a from-by-to matrix with
# `from` as row names.
# NOTE(review): a `from` node whose edges sum to 0 would produce NaN here, and
# column order follows the order in which `to` values first appear -- confirm
# the adjacency list is complete and consistently ordered.
transmat <- adjlist %>%
  group_by(from) %>%
  mutate(edge = edge / sum(edge)) %>%
  ungroup() %>%
  pivot_wider(names_from = to, values_from = edge) %>%
  column_to_rownames("from") %>%
  as.matrix()
# Set to TRUE to rebuild the cleaned parameter table from the individual
# per-fit CSVs; when FALSE, the previously written clean_param_fits.csv is
# simply read back in.
load_params_from_scratch <- FALSE

# Pattern for the model identifier embedded in each fit's filename, including
# the underscores that delimit it. Defined once so that the file filter and
# the model-ID extraction below cannot drift apart.
model_id_regex <- str_c(
  "_",
  "(bfs_(backward|forward)|",
  "ideal_obs|",
  "sr_(analytic|delta_rule))_"
)

if (isTRUE(load_params_from_scratch)) {
  params <- here("data", "param_fits") %>%
    # Find every fit file belonging to one of the candidate models
    fs::dir_ls(
      recurse = 1,
      regexp = str_c(model_id_regex, "(.)+\\.csv")
    ) %>%
    # Stack all files, remembering which file each row came from
    map_dfr(
      .f = ~read_csv(.x, show_col_types = FALSE),
      .id = "filename"
    ) %>%
    mutate(
      # Recover model ID (then strip the leading/trailing underscore)
      model = str_extract(filename, model_id_regex),
      model = str_sub(model, 2, -2),
      # Recover study ID
      study = str_extract(filename, "study[[:digit:]]"),
      study = str_replace(study, "study", "Study "),
      # Recover subject ID
      sub_id = str_extract(filename, "sub_[[:digit:]]+"),
      sub_id = str_remove(sub_id, "sub_"),
      sub_id = as.integer(sub_id),
      # Recover measurement ID
      measurement_id = str_extract(filename, "_D[[:digit:]]b?"),
      measurement_id = str_remove(measurement_id, "_"),
      # Get parameter values: prefer the human-readable value when it exists
      param_value = coalesce(param_value_human_readable, param_value)
    ) %>%
    # Find best-fitting optimization run
    filter(convergence == "converged") %>%
    group_by(model, study, sub_id, measurement_id) %>%
    slice_min(optim_value, n = 1) %>%
    # Some subjects may have had multiple "best" optimization runs
    # In that case, just go with whichever "best" run was estimated first
    # (slice_min() preserves the grouping, so no regrouping is needed)
    slice_min(optimizer_run, n = 1) %>%
    ungroup() %>%
    # Clean up
    select(
      model, study, sub_id, measurement_id,
      param_name, param_value,
      neg_loglik = optim_value
    ) %>%
    arrange(model, study, sub_id, measurement_id, param_name)

  here("data", "param_fits", "clean_params") %>%
    create_path()

  params %>%
    write_csv(
      file = here("data", "param_fits", "clean_params", "clean_param_fits.csv")
    )
}

# Always read the cleaned table back from disk, so that downstream code sees
# the same column types whether or not the table was just rebuilt
params <- here("data", "param_fits", "clean_params", "clean_param_fits.csv") %>%
  read_csv(show_col_types = FALSE)

AICc

As our metric of log-evidence, we’ll use AICc, i.e., AIC corrected for a relatively small N.

# Compute AIC and small-sample-corrected AIC (AICc) for every
# study/subject/measurement/model fit.
#
#   AIC  = 2 * negLL + 2k
#   AICc = AIC + 2k(k + 1) / (n - k - 1)
#
# where k is the number of free parameters and n the number of datapoints.
aicc <- params %>%
  # `params` has one row per parameter; collapse to one row per model fit
  select(study, sub_id, measurement_id, model, neg_loglik) %>%
  distinct() %>%
  arrange(study, sub_id, measurement_id) %>%
  mutate(
    # The SR models ("sr_analytic", "sr_delta_rule") estimate two parameters
    # (softmax_temperature and sr_gamma); all other models estimate one.
    # The model IDs must be matched by prefix: no model is named just "sr",
    # so an exact comparison would silently treat the SR as having 1 parameter.
    n_params = if_else(str_detect(model, "^sr_"), 2, 1),
    # NOTE(review): number of trials per fit is hard-coded -- confirm that
    # every subject/measurement contributes exactly 115 datapoints
    n_datapoints = 115,
    aic = (2 * neg_loglik) + (2 * n_params),
    aicc = aic + (
      (2 * n_params * (n_params + 1)) / (n_datapoints - n_params - 1)
    )
  )

SR analytic vs delta-rule

Before we go any further, let’s just get one targeted comparison out of the way. In earlier scripts, we saw that SR matrices can be constructed using a closed-form analytic solution, or a delta-rule updating mechanism. These different implementations can, in principle, end up making very different predictions. Here, we’ll see whether there’s evidence that one implementation fits better than the other.

Below, we’re directly comparing the AICc of the analytic vs delta-rule implementations. Each datapoint is one subject, and the lines connect a subject’s AICc from one implementation to the other. The red bars reflect the means. The lines are remarkably flat, indicating that there is functionally no real difference in the model goodness-of-fit.

# Paired comparison of AICc between the two SR implementations. Within each
# study/measurement facet: one point per subject per implementation, a thin
# line joining each subject's two fits, and a red crossbar at the mean.
plot_sr_comparison_aicc <- aicc %>%
  filter(str_detect(model, "sr_")) %>%
  mutate(facet_label = str_c(study, ", ", measurement_id)) %>%
  ggplot(aes(x=model, y=aicc)) +
  theme_custom() +
  facet_wrap(~facet_label, scales = "free_x") +
  geom_point(alpha = 0.5) +
  # Lines connect the same subject across implementations
  geom_line(aes(group = sub_id), alpha = 0.2) +
  stat_summary(geom = "crossbar", fun = mean, color = "red") +
  scale_x_discrete(
    name = "SR implementation",
    labels = c("sr_analytic"="Analytic", "sr_delta_rule"="Delta-rule")
  ) +
  ylab("AICc") +
  ggtitle("AICc comparison of SR implementations")

plot_sr_comparison_aicc

# Only write the figure to disk when knitting
if (knitting) {
  ggsave(
    filename = here(
      "outputs", workflow_name,
      "sr_aicc_analytic_vs_delta_rule.pdf"
    ),
    plot = plot_sr_comparison_aicc,
    width = 6, height = 4,
    units = "in", dpi = 300
  )
}

We’re not interested in doing any sort of hypothesis testing here, but for the purpose of doing model selection, we do want to know whether choosing one implementation over the other might result in forming different conclusions about the parameter fits.

Below, the bars reflect medians. We can see that, by-and-large, the two implementations result in very similar estimates.

# Estimated SR gamma per subject, colored by SR implementation, with one facet
# per study and measurement on the x-axis. Crossbars mark the median.
plot_sr_comparison_gamma <- params %>%
  filter(str_detect(model, "sr_")) %>%
  filter(param_name == "sr_gamma") %>%
  select(model, study, sub_id, measurement_id, sr_gamma = param_value) %>%
  ggplot(aes(x=measurement_id, y=sr_gamma, color=model)) +
  theme_custom() +
  facet_wrap(~study, scales = "free_x") +
  geom_point(
    alpha = 0.5, position = position_dodge(width = 0.25), show.legend = FALSE
  ) +
  # One line per subject per implementation, across measurements. Facets with
  # a single measurement have one observation per group, which is what
  # triggers ggplot2's "Each group consists of only one observation" message.
  # NOTE(review): points are dodged but lines and crossbars are not, so they
  # do not line up exactly -- confirm this is intended
  geom_line(aes(group = interaction(model, sub_id)), alpha = 0.2) +
  stat_summary(geom = "crossbar", fun = median) +
  xlab("Measurement") +
  ylab("SR gamma") +
  scale_color_manual(
    name = "SR implementation",
    labels = c("sr_analytic"="Analytic", "sr_delta_rule"="Delta-rule"),
    values = c("sr_analytic"="#ca0020", "sr_delta_rule"="#0571b0")
  ) +
  ggtitle("Estimates of SR gamma: Analytic vs delta-rule") +
  theme(legend.position = "bottom")

plot_sr_comparison_gamma
## `geom_line()`: Each group consists of only one observation.
## ℹ Do you need to adjust the group aesthetic?

# Only write the SR gamma comparison figure to disk when knitting
if (knitting) {
  ggsave(
    filename = here(
      "outputs", workflow_name,
      "sr_gamma_analytic_vs_delta_rule.pdf"
    ),
    plot = plot_sr_comparison_gamma,
    width = 6, height = 4,
    units = "in", dpi = 300
  )
}
## `geom_line()`: Each group consists of only one observation.
## ℹ Do you need to adjust the group aesthetic?

In the modeling, the goal was to test changes in the gamma parameter, assuming an asymptotic representation. These results suggest that there is nothing lost by using the analytic closed-form implementation, so we’ll stick with that from here onwards.

Description of parameter fits

# Table of the mean and median of every fitted parameter, per model, study,
# and measurement, with rows grouped by model.
# NOTE(review): the caption is passed here as `captions = ...`, whereas the PXP
# table later in this file passes it positionally -- confirm that `captions`
# is the argument name kable_custom() actually expects
params %>%
  group_by(model, study, measurement_id, param_name) %>%
  summarise(
    param_mean = mean(param_value),
    param_median = median(param_value),
    .groups = "drop"
  ) %>%
  kable_custom(
    captions = "Descriptive stats: parameter fits",
    grouping_var = model
  )
Descriptive stats: parameter fits
study measurement_id param_name param_mean param_median
bfs_backward
Study 1 D1 search_threshold 3.834 3.659
Study 2 D1 search_threshold 4.769 4.431
Study 2 D2 search_threshold 5.319 4.762
Study 3 D1 search_threshold 5.072 4.535
Study 3 D1b search_threshold 4.870 4.144
Study 3 D2 search_threshold 5.763 4.652
bfs_forward
Study 1 D1 search_threshold 5.926 6.038
Study 2 D1 search_threshold 7.438 6.854
Study 2 D2 search_threshold 8.116 7.168
Study 3 D1 search_threshold 7.783 6.625
Study 3 D1b search_threshold 7.611 6.606
Study 3 D2 search_threshold 8.618 7.404
ideal_obs
Study 1 D1 softmax_temperature -0.272 -0.184
Study 2 D1 softmax_temperature -0.475 -0.289
Study 2 D2 softmax_temperature -0.685 -0.353
Study 3 D1 softmax_temperature -0.687 -0.284
Study 3 D1b softmax_temperature -0.696 -0.319
Study 3 D2 softmax_temperature -1.024 -0.380
sr_analytic
Study 1 D1 softmax_temperature 1797.497 26.023
Study 1 D1 sr_gamma 0.567 0.647
Study 2 D1 softmax_temperature 285.337 24.286
Study 2 D1 sr_gamma 0.659 0.722
Study 2 D2 softmax_temperature 560.735 30.060
Study 2 D2 sr_gamma 0.722 0.826
Study 3 D1 softmax_temperature 688.291 31.783
Study 3 D1 sr_gamma 0.690 0.816
Study 3 D1b softmax_temperature 761.781 40.193
Study 3 D1b sr_gamma 0.673 0.825
Study 3 D2 softmax_temperature 2041.139 54.912
Study 3 D2 sr_gamma 0.718 0.834
sr_delta_rule
Study 1 D1 softmax_temperature 45.047 12.202
Study 1 D1 sr_gamma 0.557 0.639
Study 2 D1 softmax_temperature 32.039 13.713
Study 2 D1 sr_gamma 0.649 0.709
Study 2 D2 softmax_temperature 119.106 21.976
Study 2 D2 sr_gamma 0.714 0.819
Study 3 D1 softmax_temperature 105.173 18.794
Study 3 D1 sr_gamma 0.680 0.790
Study 3 D1b softmax_temperature 40.854 15.469
Study 3 D1b sr_gamma 0.667 0.813
Study 3 D2 softmax_temperature 278.951 24.712
Study 3 D2 sr_gamma 0.716 0.835
# Plot per-subject estimates of a single parameter from a single model.
#
# Each subject is a point (with a line joining that subject's estimates across
# measurements), faceted by study. Crossbars mark the mean (red) and the
# median (blue).
#
# @param model_name Model ID as it appears in `params$model`.
# @param param Name of the parameter to plot, as it appears in
#   `params$param_name` (it becomes a column after pivoting wider).
# @param y_label Y-axis label.
# @param title Plot title.
# @return A ggplot object.
plot_param_estimates <- function(model_name, param, y_label, title) {
  # Points and lines use the same seeded jitter specification, which is
  # intended to keep each subject's line aligned with that subject's points
  subject_jitter <- position_jitter(width = 0.1, height = 0, seed = 1)

  params %>%
    filter(model == model_name) %>%
    pivot_wider(names_from = param_name, values_from = param_value) %>%
    ggplot(aes(x = measurement_id, y = .data[[param]])) +
    theme_custom() +
    facet_wrap(~study, scales = "free_x") +
    stat_summary(geom = "crossbar", fun = mean, color = "red") +
    stat_summary(geom = "crossbar", fun = median, color = "blue") +
    geom_point(alpha = 0.1, position = subject_jitter) +
    geom_line(aes(group = sub_id), alpha = 0.25, position = subject_jitter) +
    xlab("Measurement") +
    ylab(y_label) +
    ggtitle(title)
}

plot_params_bfs_backward <- plot_param_estimates(
  model_name = "bfs_backward",
  param = "search_threshold",
  y_label = "Search threshold",
  title = "BFS-backward"
)

plot_params_bfs_forward <- plot_param_estimates(
  model_name = "bfs_forward",
  param = "search_threshold",
  y_label = "Search threshold",
  title = "BFS-forward"
)

plot_params_ideal_obs <- plot_param_estimates(
  model_name = "ideal_obs",
  param = "softmax_temperature",
  y_label = "Inverse temperature",
  title = "Ideal observer"
)

# The SR temperature estimates are extremely skewed, so zoom in on the range
# 0-3000; coord_cartesian() zooms without dropping data, so the mean/median
# crossbars are still computed from all subjects
plot_params_sr_temp <- plot_param_estimates(
  model_name = "sr_analytic",
  param = "softmax_temperature",
  y_label = "Inverse temperature",
  title = "Successor Representation"
) +
  coord_cartesian(ylim = c(0, 3000))

plot_params_sr_gamma <- plot_param_estimates(
  model_name = "sr_analytic",
  param = "sr_gamma",
  y_label = "Gamma",
  title = "Successor Representation"
)

# Combine into one figure: three single-parameter models on the top row, the
# two SR parameters on the bottom row
plot_params_all <- (
  (plot_params_bfs_backward | plot_params_bfs_forward | plot_params_ideal_obs) /
    (plot_params_sr_temp | plot_params_sr_gamma)
) +
  plot_annotation(
    title = "Estimated parameters",
    tag_levels = "A", tag_suffix = ".",
    theme = theme(plot.title = element_text(hjust = 0.5))
  )

plot_params_all
## `geom_line()`: Each group consists of only one observation.
## ℹ Do you need to adjust the group aesthetic?
## `geom_line()`: Each group consists of only one observation.
## ℹ Do you need to adjust the group aesthetic?
## `geom_line()`: Each group consists of only one observation.
## ℹ Do you need to adjust the group aesthetic?
## `geom_line()`: Each group consists of only one observation.
## ℹ Do you need to adjust the group aesthetic?
## `geom_line()`: Each group consists of only one observation.
## ℹ Do you need to adjust the group aesthetic?

# Only write the combined parameter-estimate figure to disk when knitting
if (knitting) {
  ggsave(
    filename = here("outputs", workflow_name, "param_estimates.pdf"),
    plot = plot_params_all,
    width = 12, height = 6,
    units = "in", dpi = 300
  )
}
## `geom_line()`: Each group consists of only one observation.
## ℹ Do you need to adjust the group aesthetic?
## `geom_line()`: Each group consists of only one observation.
## ℹ Do you need to adjust the group aesthetic?
## `geom_line()`: Each group consists of only one observation.
## ℹ Do you need to adjust the group aesthetic?
## `geom_line()`: Each group consists of only one observation.
## ℹ Do you need to adjust the group aesthetic?
## `geom_line()`: Each group consists of only one observation.
## ℹ Do you need to adjust the group aesthetic?

Akaike weights

We’ll later use protected exceedance probabilities (PXP) to do formal inference to test whether a particular model provides a significantly better group-level fit than other models. But first, we do want to acknowledge that there are likely to be some individual differences in how well a particular model fits each subject. To get a sense for this, we’ll use Akaike weights, which provide the probability that a particular model is the “best” given the data and the set of candidate models.

# Akaike weights for the candidate models (the delta-rule SR is excluded),
# computed separately within each subject's measurement:
#   - relative_likelihood: exp(-0.5 * difference from the subject's best AICc)
#   - akaike_weight: relative likelihood normalized to sum to 1
#   - evidence_ratio: the best model's weight divided by this model's weight
candidate_model_fits <- filter(aicc, model != "sr_delta_rule")

akaike_weights <- candidate_model_fits %>%
  mutate(
    relative_likelihood = exp(-0.5 * (aicc - min(aicc))),
    akaike_weight = relative_likelihood / sum(relative_likelihood),
    evidence_ratio = max(akaike_weight) / akaike_weight,
    .by = c(study, measurement_id, sub_id)
  ) %>%
  arrange(study, measurement_id, sub_id, evidence_ratio)

We can first average over all subjects’ Akaike weights to get a sense for what the “best-fitting” model is across subjects. This suggests that the SR consistently comes out on top, followed pretty consistently by BFS-backward.

# Stacked bar chart of the mean Akaike weight per model within each
# study/measurement. Each segment is labelled with its mean weight rounded to
# two decimals; the y-axis text and ticks are hidden.
plot_akaike_group <- akaike_weights %>%
  group_by(study, measurement_id, model) %>%
  summarise(akaike_weight = mean(akaike_weight), .groups = "drop") %>%
  mutate(text = round(akaike_weight, 2)) %>%
  ggplot(aes(x=measurement_id, y=akaike_weight, fill=model)) +
  theme_custom() +
  facet_wrap(~study, scales = "free_x") +
  geom_col() +
  # Center each label within its stacked segment
  geom_text(aes(label = text), position = position_stack(vjust = 0.5)) +
  scale_x_discrete(
    name = NULL,
    labels = c(
      "D1"="Before\nrest",
      "D2"="After\novernight\nrest",
      "D1b"="After\nawake\nrest"
    )
  ) +
  scale_y_continuous(
    name = NULL,
    expand = expansion(mult = c(0.01, 0.01))
  ) +
  scale_fill_manual(
    name = NULL,
    labels = c(
      "bfs_backward" = "BFS-backward",
      "bfs_forward" = "BFS-forward",
      "ideal_obs" = "Ideal observer",
      "sr_analytic" = "Successor Representation"
    ),
    values = c(
      "bfs_backward" = "#a6dba0",
      "bfs_forward" = "#5aae61",
      "ideal_obs" = "#1b7837",
      "sr_analytic" = "#af8dc3"
    )
  ) +
  theme(
    legend.position = "bottom",
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank()
  ) +
  ggtitle("Akaike weights")

plot_akaike_group

# Only write the figure to disk when knitting
if (knitting) {
  ggsave(
    filename = here("outputs", workflow_name, "akaike_weights_group.pdf"),
    plot = plot_akaike_group,
    width = 6, height = 4,
    units = "in", dpi = 300
  )
}

We can break this out and plot each individual subject’s Akaike weights.

# Per-subject stacked bars of Akaike weights: one bar per subject, one facet
# row per study/measurement (in the fixed order given to fct_relevel()).
plot_akaike_individual <- akaike_weights %>%
  mutate(
    # Treat subject IDs as discrete so that each subject gets its own bar
    sub_id = factor(sub_id),
    # NOTE(review): measurement IDs other than D1/D1b/D2 match no branch and
    # would become NA here -- confirm none are present in the fits
    measurement_id = case_when(
      measurement_id == "D1" ~ "before rest",
      measurement_id == "D1b" ~ "after awake rest",
      measurement_id == "D2" ~ "after overnight rest"
    ),
    # Reuse `study` as the combined facet label
    study = str_c(study, ", ", measurement_id),
    study = fct_relevel(
      study,
      "Study 1, before rest",
      "Study 2, before rest",
      "Study 2, after overnight rest",
      "Study 3, before rest",
      "Study 3, after awake rest",
      "Study 3, after overnight rest"
    )
  ) %>%
  ggplot(aes(x=sub_id, y=akaike_weight, fill=model)) +
  theme_custom() +
  facet_wrap(~study, scales = "free_x", ncol = 1) +
  geom_col() +
  scale_x_discrete(name = "Subject ID") +
  scale_y_continuous(
    name = NULL,
    expand = expansion(mult = c(0.01, 0.01))
  ) +
  scale_fill_manual(
    name = NULL,
    labels = c(
      "bfs_backward" = "BFS-backward",
      "bfs_forward" = "BFS-forward",
      "ideal_obs" = "Ideal observer",
      "sr_analytic" = "Successor Representation"
    ),
    values = c(
      "bfs_backward" = "#a6dba0",
      "bfs_forward" = "#5aae61",
      "ideal_obs" = "#1b7837",
      "sr_analytic" = "#af8dc3"
    )
  ) +
  theme(
    legend.position = "bottom",
    # Rotate subject IDs so that they do not overlap
    axis.text.x = element_text(angle = 90, vjust = 0.5, hjust = 1),
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank()
  ) +
  ggtitle("Akaike weights")

plot_akaike_individual

# Only write the figure to disk when knitting
if (knitting) {
  ggsave(
    filename = here(
      "outputs", workflow_name,
      "akaike_weights_per_subject.pdf"
    ),
    plot = plot_akaike_individual,
    width = 8, height = 10,
    units = "in", dpi = 300
  )
}

Akaike weights provide a nice goodness-of-fit metric that respects the probabilistic aspect of model comparison, and can do so at both the group- and individual-level. However, model selection requires us to ultimately make a discrete choice. If we made per-subject decisions simply by choosing the single best-fitting model, what proportion of subjects are best-fit by each model? We see that the pattern of results basically mirrors what we’d seen in the Akaike weights, such that the SR is the best-fitting model for the majority of subjects, followed by BFS-backward.

# For every subject/measurement, identify the single model with the largest
# Akaike weight.
#
# `with_ties = FALSE` guarantees exactly one row per subject/measurement. With
# the default (`with_ties = TRUE`), a subject whose top weights are exactly
# tied would be returned once per tied model, and would then be counted more
# than once when computing the proportion of subjects best fit by each model.
# In the event of a tie, the first tied row in `akaike_weights` is kept.
best_fitting_model_per_sub <- akaike_weights %>%
  group_by(study, measurement_id, sub_id) %>%
  slice_max(akaike_weight, n = 1, with_ties = FALSE) %>%
  ungroup() %>%
  select(study, measurement_id, sub_id, best_fitting_model = model)

# Stacked bar chart of the proportion of subjects best fit by each model,
# within each study/measurement. Segments are labelled with the percentage.
plot_best_fitting_model_prop <- best_fitting_model_per_sub %>%
  # Number of subjects best fit by each model
  count(study, measurement_id, best_fitting_model) %>%
  group_by(study, measurement_id) %>%
  mutate(
    # Convert counts to proportions within each study/measurement
    p = n / sum(n),
    text = str_c(round(p, 2) * 100, "%")
  ) %>%
  ggplot(aes(x=measurement_id, y=p, fill=best_fitting_model)) +
  theme_custom() +
  facet_wrap(~study, scales = "free_x") +
  geom_col() +
  # Center each label within its stacked segment
  geom_text(aes(label = text), position = position_stack(vjust = 0.5)) +
  scale_x_discrete(
    name = NULL,
    labels = c(
      "D1"="Before\nrest",
      "D2"="After\novernight\nrest",
      "D1b"="After\nawake\nrest"
    )
  ) +
  scale_y_continuous(
    name = NULL,
    expand = expansion(mult = c(0.01, 0.01))
  ) +
  scale_fill_manual(
    name = NULL,
    labels = c(
      "bfs_backward" = "BFS-backward",
      "bfs_forward" = "BFS-forward",
      "ideal_obs" = "Ideal observer",
      "sr_analytic" = "Successor Representation"
    ),
    values = c(
      "bfs_backward" = "#a6dba0",
      "bfs_forward" = "#5aae61",
      "ideal_obs" = "#1b7837",
      "sr_analytic" = "#af8dc3"
    )
  ) +
  theme(
    legend.position = "bottom",
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank()
  ) +
  ggtitle("Proportion of subjects best fit by each model")

plot_best_fitting_model_prop

# Only write the figure to disk when knitting
if (knitting) {
  ggsave(
    filename = here(
      "outputs", workflow_name,
      "prop_subjects_best_fit.pdf"
    ),
    plot = plot_best_fitting_model_prop,
    width = 6, height = 4,
    units = "in", dpi = 300
  )
}

Protected exceedance probabilities

Protected exceedance probabilities provide a formal test of a model’s group-level fit compared to other candidate models. To run this analysis, we’ll adapt software originally written by Matteo Lisi (https://github.com/mattelisi/bmsR). We’ll use AICc as our metric of log-evidence.

# Run Bayesian model selection separately for each study/measurement, using
# sign-flipped AICc as the log-evidence. The input to each call is a
# subjects-by-models table (one column per candidate model).
pxp_results <- aicc %>%
  filter(model != "sr_delta_rule") %>%
  # The model selection routine expects log-evidence, where larger is better.
  # AICc, in contrast, is based off neg-LL, and so is interpreted as
  # "smaller is better". So, do a sign flip.
  mutate(aicc = -aicc) %>%
  select(study, measurement_id, sub_id, model, aicc) %>%
  # Compute PXP for each study/measurement
  # NOTE(review): a subject lacking a converged fit for any model would get an
  # NA in that model's column after pivoting -- confirm how
  # bayesian_model_selection() handles missing values
  pivot_wider(names_from = model, values_from = aicc) %>%
  select(-sub_id) %>%
  group_by(study, measurement_id) %>%
  nest() %>%
  mutate(
    test = map(
      .x = data,
      .f = ~bayesian_model_selection(.x)
    )
  ) %>%
  ungroup() %>%
  # Expand each study/measurement's results into rows, then drop the raw input
  unnest(test) %>%
  select(-data)

In the results, it’s clear that the SR comes out on top, and by a large margin.

# Table of the model selection results, grouped by study/measurement (in a
# fixed order) and sorted from highest to lowest PXP within each group.
pxp_results %>%
  mutate(
    measurement_id = case_when(
      measurement_id == "D1" ~ "before rest",
      measurement_id == "D1b" ~ "after awake rest",
      measurement_id == "D2" ~ "after overnight rest"
    ),
    # Reuse `study` as the combined group label
    study = str_c(study, ", ", measurement_id),
    study = fct_relevel(
      study,
      "Study 1, before rest",
      "Study 2, before rest",
      "Study 2, after overnight rest",
      "Study 3, before rest",
      "Study 3, after awake rest",
      "Study 3, after overnight rest"
    )
  ) %>%
  select(-measurement_id) %>%
  arrange(study, desc(pxp)) %>%
  kable_custom("PXP results", grouping_var = study)
PXP results
model alpha expected_model_frequencies omnibus_risk xp pxp
Study 1, before rest
sr_analytic 36.549 0.677 0.000 0.999459 9.994587e-01
bfs_backward 14.010 0.259 0.000 0.000541 5.410882e-04
bfs_forward 1.475 0.027 0.000 0.000000 8.835920e-08
ideal_obs 1.967 0.036 0.000 0.000000 8.835920e-08
Study 2, before rest
sr_analytic 28.381 0.526 0.011 0.983780 9.758328e-01
bfs_backward 14.577 0.270 0.011 0.015947 1.848192e-02
ideal_obs 8.476 0.157 0.011 0.000273 2.977680e-03
bfs_forward 2.566 0.048 0.011 0.000000 2.707637e-03
Study 2, after overnight rest
sr_analytic 32.103 0.595 0.000 0.998161 9.981347e-01
bfs_backward 13.044 0.242 0.000 0.001818 1.826722e-03
ideal_obs 7.667 0.142 0.000 0.000021 2.978537e-05
bfs_forward 1.186 0.022 0.000 0.000000 8.786111e-06
Study 3, before rest
sr_analytic 26.939 0.539 0.005 0.976547 9.726377e-01
bfs_backward 14.387 0.288 0.005 0.023440 2.465903e-02
ideal_obs 5.363 0.107 0.005 0.000012 1.357085e-03
bfs_forward 3.312 0.066 0.005 0.000001 1.346145e-03
Study 3, after awake rest
sr_analytic 32.674 0.653 0.001 0.999898 9.994390e-01
ideal_obs 9.419 0.188 0.001 0.000102 2.549721e-04
bfs_backward 5.342 0.107 0.001 0.000000 1.530346e-04
bfs_forward 2.565 0.051 0.001 0.000000 1.530346e-04
Study 3, after overnight rest
sr_analytic 30.775 0.615 0.001 0.999712 9.988405e-01
ideal_obs 9.425 0.188 0.001 0.000237 5.273355e-04
bfs_backward 7.940 0.159 0.001 0.000051 3.415517e-04
bfs_forward 1.861 0.037 0.001 0.000000 2.906110e-04
# Stacked bar chart of protected exceedance probabilities per model within
# each study/measurement. Only the SR segment is labelled (as a percentage);
# labels for all other models are blanked out.
plot_pxp <- pxp_results %>%
  mutate(
    text = round(pxp, 2) * 100,
    text = if_else(model != "sr_analytic", "", str_c(text, "%"))
  ) %>%
  ggplot(aes(x=measurement_id, y=pxp, fill=model)) +
  theme_custom() +
  facet_wrap(~study, scales = "free_x") +
  geom_col() +
  # Center each label within its stacked segment
  geom_text(aes(label = text), position = position_stack(vjust = 0.5)) +
  scale_x_discrete(
    name = NULL,
    labels = c(
      "D1"="Before\nrest",
      "D2"="After\novernight\nrest",
      "D1b"="After\nawake\nrest"
    )
  ) +
  scale_y_continuous(
    name = NULL,
    expand = expansion(mult = c(0.01, 0.01))
  ) +
  scale_fill_manual(
    name = NULL,
    labels = c(
      "bfs_backward" = "BFS-backward",
      "bfs_forward" = "BFS-forward",
      "ideal_obs" = "Ideal observer",
      "sr_analytic" = "Successor Representation"
    ),
    values = c(
      "bfs_backward" = "#a6dba0",
      "bfs_forward" = "#5aae61",
      "ideal_obs" = "#1b7837",
      "sr_analytic" = "#af8dc3"
    )
  ) +
  theme(
    legend.position = "bottom",
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank()
  ) +
  ggtitle("Protected exceedance probabilities")

plot_pxp

# Only write the figure to disk when knitting
if (knitting) {
  ggsave(
    filename = here(
      "outputs", workflow_name,
      "pxp.pdf"
    ),
    plot = plot_pxp,
    width = 6, height = 4,
    units = "in", dpi = 300
  )
}

Posterior predictive check

Simulate predicted behaviors

It’s nice to see consistency in the quantitative model comparison and selection, but we’d also like to see how well human behaviors are qualitatively described by our models. To do this, we’ll simulate subjects’ predicted behaviors given their model parameters.

# Posterior predictive check for BFS-backward: for every subject/measurement
# and every simulated trial, compute the model's predicted probability of a
# correct choice given the subject's fitted search threshold, then attach the
# subject's actual behavior on the matching trial.
ppc_bfs_backward <- expand_grid(
  # List of all trials from BFS simulation for each subject/measurement
  params %>%
    select(study, sub_id, measurement_id) %>%
    distinct(),
  bfs_backward_sims
) %>%
  # Add subject-specific parameters
  left_join(
    params %>%
      filter(model == "bfs_backward") %>%
      pivot_wider(names_from = param_name, values_from = param_value) %>%
      select(study, sub_id, measurement_id, search_threshold),
    by = join_by(study, sub_id, measurement_id)
  ) %>%
  # What's the probability of *completing* BFS-online all the way through?
  # softmax() is applied one row at a time, comparing the subject's search
  # threshold against the mean number of BFS visits for that trial
  # NOTE(review): softmax() is defined in a sourced utility file; its exact
  # semantics (e.g. how `temperature` scales the values) are not visible here
  rowwise() %>%
  mutate(
    p_complete_bfs = softmax(
      option_values = c(search_threshold, bfs_visits),
      option_chosen = 1,
      temperature = 1
    )
  ) %>%
  ungroup() %>%
  # Weigh BFS predictions accordingly
  mutate(
    p_give_up = 1 - p_complete_bfs,
    # Mixture of BFS accuracy (if the search completes) and a 1/2 chance of
    # being correct (if the search is abandoned)
    model_p_correct = (
      (p_complete_bfs * p_bfs_correct) + (p_give_up * 1/2)
    )
  ) %>%
  # Add subjects' actual choices
  # NOTE(review): if a subject saw the same trial more than once within a
  # measurement, this join would duplicate the model's row -- confirm trials
  # are unique per subject/measurement
  left_join(
    bind_rows(nav_study1, nav_study2, nav_study3) %>%
      select(
        study, sub_id, measurement_id,
        shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id,
        sub_choice, sub_correct = correct, sub_rt = rt
      ),
    by = join_by(
      study, sub_id, measurement_id,
      shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id
    )
  )
# Posterior predictive check for BFS-forward: cross every subject/measurement
# with every simulated trial, score each trial with that subject's fitted
# search threshold, then attach the subject's actual response.
ppc_bfs_forward <- params %>%
  # One row per subject/measurement ...
  distinct(study, sub_id, measurement_id) %>%
  # ... crossed with all trials from the BFS simulation
  expand_grid(bfs_forward_sims) %>%
  # Attach each subject's fitted search threshold
  left_join(
    params %>%
      filter(model == "bfs_forward") %>%
      pivot_wider(names_from = param_name, values_from = param_value) %>%
      select(study, sub_id, measurement_id, search_threshold),
    by = join_by(study, sub_id, measurement_id)
  ) %>%
  # softmax() is called one trial at a time
  rowwise() %>%
  mutate(
    # Probability of *completing* the search: softmax over the subject's
    # threshold vs. the number of visits the search needs on this trial
    p_complete_bfs = softmax(
      option_values = c(search_threshold, bfs_visits),
      option_chosen = 1,
      temperature = 1
    ),
    # An abandoned search falls back to a coin flip between the two options
    p_give_up = 1 - p_complete_bfs,
    model_p_correct = (p_complete_bfs * p_bfs_correct) + (p_give_up / 2)
  ) %>%
  ungroup() %>%
  # Attach what the subject actually did on the matching trial
  left_join(
    bind_rows(nav_study1, nav_study2, nav_study3) %>%
      select(
        study, sub_id, measurement_id,
        shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id,
        sub_choice,
        sub_correct = correct,
        sub_rt = rt
      ),
    by = join_by(
      study, sub_id, measurement_id,
      shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id
    )
  )
# Posterior predictive check for the ideal observer: the softmax is applied to
# the two options' distances, using each subject's fitted temperature.
ppc_ideal_obs <- params %>%
  # One row per subject/measurement ...
  distinct(study, sub_id, measurement_id) %>%
  # ... crossed with the full trial list
  expand_grid(nav_trials) %>%
  # Attach each subject's fitted softmax temperature
  left_join(
    params %>%
      filter(model == "ideal_obs") %>%
      pivot_wider(names_from = param_name, values_from = param_value) %>%
      select(study, sub_id, measurement_id, softmax_temperature),
    by = join_by(study, sub_id, measurement_id)
  ) %>%
  # softmax() is called one trial at a time
  rowwise() %>%
  mutate(
    # Predicted probability of picking the correct option
    model_p_correct = softmax(
      option_values = c(opt1_distance, opt2_distance),
      option_chosen = if_else(correct_choice == opt1_id, 1, 2),
      temperature = softmax_temperature,
      use_inverse_temperature = TRUE
    )
  ) %>%
  ungroup() %>%
  # Attach what the subject actually did on the matching trial
  left_join(
    bind_rows(nav_study1, nav_study2, nav_study3) %>%
      select(
        study, sub_id, measurement_id,
        shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id,
        sub_choice,
        sub_correct = correct,
        sub_rt = rt
      ),
    by = join_by(
      study, sub_id, measurement_id,
      shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id
    )
  )
# Build each subject's predicted successor representation from their fitted
# gamma, then unnest so every SR entry becomes its own row.
ppc_sr_representation <- params %>%
  filter(model == "sr_analytic") %>%
  pivot_wider(names_from = param_name, values_from = param_value) %>%
  select(study, sub_id, measurement_id, sr_gamma) %>%
  mutate(
    # map() already iterates over subjects, yielding one SR per row
    predicted_sr = map(
      sr_gamma,
      function(gamma) {
        build_successor_analytically(
          transmat, successor_horizon = gamma, normalize = TRUE
        )
      }
    )
  ) %>%
  # gamma has served its purpose; keep only the identifiers and the SR
  select(-sr_gamma) %>%
  unnest(predicted_sr)

# Posterior predictive check for the SR: look up the SR value linking each
# option to the endpoint, then softmax over the two values.
ppc_sr_navigation <- params %>%
  # One row per subject/measurement ...
  distinct(study, sub_id, measurement_id) %>%
  # ... crossed with the full trial list
  expand_grid(nav_trials) %>%
  # Attach each subject's fitted softmax temperature
  left_join(
    params %>%
      filter(model == "sr_analytic") %>%
      pivot_wider(names_from = param_name, values_from = param_value) %>%
      select(study, sub_id, measurement_id, softmax_temperature),
    by = join_by(study, sub_id, measurement_id)
  ) %>%
  # SR value from option 1 to the endpoint
  left_join(
    ppc_sr_representation %>%
      select(
        study, sub_id, measurement_id,
        endpoint_id = to,
        opt1_id = from,
        opt1_sr = sr_value
      ),
    by = join_by(study, sub_id, measurement_id, endpoint_id, opt1_id)
  ) %>%
  # SR value from option 2 to the endpoint
  left_join(
    ppc_sr_representation %>%
      select(
        study, sub_id, measurement_id,
        endpoint_id = to,
        opt2_id = from,
        opt2_sr = sr_value
      ),
    by = join_by(study, sub_id, measurement_id, endpoint_id, opt2_id)
  ) %>%
  # softmax() is called one trial at a time
  rowwise() %>%
  mutate(
    # Predicted probability of picking the correct option
    model_p_correct = softmax(
      option_values = c(opt1_sr, opt2_sr),
      option_chosen = if_else(correct_choice == opt1_id, 1, 2),
      temperature = softmax_temperature,
      use_inverse_temperature = TRUE
    )
  ) %>%
  ungroup() %>%
  # Attach what the subject actually did on the matching trial
  left_join(
    bind_rows(nav_study1, nav_study2, nav_study3) %>%
      select(
        study, sub_id, measurement_id,
        shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id,
        sub_choice,
        sub_correct = correct,
        sub_rt = rt
      ),
    by = join_by(
      study, sub_id, measurement_id,
      shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id
    )
  )

Plot PPC

# Long-format table of mean accuracy per subject/measurement/distance for the
# human data and each model, restricted to trials with one correct option.
ppc_for_plotting <- params %>%
  distinct(study, sub_id, measurement_id) %>%
  # Human accuracy + BFS-backward accuracy. This join introduces
  # shortest_path, so it is keyed on the subject identifiers only.
  left_join(
    ppc_bfs_backward %>%
      filter(two_correct_options == FALSE) %>%
      group_by(study, sub_id, measurement_id, shortest_path) %>%
      summarise(
        human = mean(sub_correct),
        bfs_backward = mean(model_p_correct),
        .groups = "drop"
      ),
    by = join_by(study, sub_id, measurement_id)
  ) %>%
  # BFS-forward accuracy
  left_join(
    ppc_bfs_forward %>%
      filter(two_correct_options == FALSE) %>%
      group_by(study, sub_id, measurement_id, shortest_path) %>%
      summarise(bfs_forward = mean(model_p_correct), .groups = "drop"),
    by = join_by(study, sub_id, measurement_id, shortest_path)
  ) %>%
  # Ideal observer accuracy
  left_join(
    ppc_ideal_obs %>%
      filter(two_correct_options == FALSE) %>%
      group_by(study, sub_id, measurement_id, shortest_path) %>%
      summarise(ideal_obs = mean(model_p_correct), .groups = "drop"),
    by = join_by(study, sub_id, measurement_id, shortest_path)
  ) %>%
  # SR accuracy
  left_join(
    ppc_sr_navigation %>%
      filter(two_correct_options == FALSE) %>%
      group_by(study, sub_id, measurement_id, shortest_path) %>%
      summarise(sr = mean(model_p_correct), .groups = "drop"),
    by = join_by(study, sub_id, measurement_id, shortest_path)
  ) %>%
  # One row per agent, with display names for the facet strips
  pivot_longer(human:sr, names_to = "agent", values_to = "accuracy") %>%
  mutate(
    agent = unname(
      c(
        "human" = "Human",
        "bfs_backward" = "BFS-backward",
        "bfs_forward" = "BFS-forward",
        "ideal_obs" = "Ideal observer",
        "sr" = "Successor Rep."
      )[agent]
    ),
    # Humans first, then the SR, then the remaining models
    agent = fct_relevel(agent, "Human", "Successor Rep.")
  )

Let’s look at the main set of trials and compare human performance against the models on Day 1 (i.e., before rest).

# PPC for the before-rest measurement (D1): accuracy by shortest-path
# distance, one panel per agent.
plot_ppc_day1 <- ppc_for_plotting %>%
  filter(measurement_id == "D1") %>%
  ggplot(aes(x = shortest_path, y = accuracy)) +
  theme_custom() +
  # Chance level
  geom_hline(yintercept = 0.5, linetype = "dashed", color = "blue") +
  # One faint line per subject
  geom_line(aes(group = interaction(study, sub_id)), alpha = 0.1) +
  # Group mean at each distance
  stat_summary(geom = "crossbar", fun = mean, color = "red") +
  facet_wrap(~agent, nrow = 1) +
  scale_y_continuous(name = "Accuracy", labels = scales::percent) +
  scale_x_discrete(name = "Shortest path distance") +
  ggtitle("Posterior predictive check: Before rest")

plot_ppc_day1

# Write the figure to disk only while knitting the HTML report
if (knitting) {
  ggsave(
    filename = here("outputs", workflow_name, "ppc_day1.pdf"),
    plot = plot_ppc_day1,
    width = 6, height = 3, units = "in", dpi = 300
  )
}

We’ll do the same now for Day 2 (i.e., after overnight rest).

# PPC for the after-overnight-rest measurement (D2): accuracy by
# shortest-path distance, one panel per agent.
plot_ppc_day2 <- ppc_for_plotting %>%
  filter(measurement_id == "D2") %>%
  ggplot(aes(x = shortest_path, y = accuracy)) +
  theme_custom() +
  # Chance level
  geom_hline(yintercept = 0.5, linetype = "dashed", color = "blue") +
  # One faint line per subject
  geom_line(aes(group = interaction(study, sub_id)), alpha = 0.1) +
  # Group mean at each distance
  stat_summary(geom = "crossbar", fun = mean, color = "red") +
  facet_wrap(~agent, nrow = 1) +
  scale_y_continuous(name = "Accuracy", labels = scales::percent) +
  scale_x_discrete(name = "Shortest path distance") +
  ggtitle("Posterior predictive check: After overnight rest")

plot_ppc_day2

# Write the figure to disk only while knitting the HTML report
if (knitting) {
  ggsave(
    filename = here("outputs", workflow_name, "ppc_day2.pdf"),
    plot = plot_ppc_day2,
    width = 6, height = 3, units = "in", dpi = 300
  )
}

And finally, for Day 1b (i.e., after awake rest on Day 1), a measurement that was only made in Study 3.

# PPC for the after-awake-rest measurement (D1b): accuracy by shortest-path
# distance, one panel per agent.
plot_ppc_day1b <- ppc_for_plotting %>%
  filter(measurement_id == "D1b") %>%
  ggplot(aes(x = shortest_path, y = accuracy)) +
  theme_custom() +
  # Chance level
  geom_hline(yintercept = 0.5, linetype = "dashed", color = "blue") +
  # One faint line per subject
  geom_line(aes(group = interaction(study, sub_id)), alpha = 0.1) +
  # Group mean at each distance
  stat_summary(geom = "crossbar", fun = mean, color = "red") +
  facet_wrap(~agent, nrow = 1) +
  scale_y_continuous(name = "Accuracy", labels = scales::percent) +
  scale_x_discrete(name = "Shortest path distance") +
  ggtitle("Posterior predictive check: After awake rest")

plot_ppc_day1b

# Write the figure to disk only while knitting the HTML report
if (knitting) {
  ggsave(
    filename = here("outputs", workflow_name, "ppc_day1b.pdf"),
    plot = plot_ppc_day1b,
    width = 6, height = 3, units = "in", dpi = 300
  )
}

Held-out trials

The primary analyses, including the parameter-fitting, were performed on a set of trials where there was always one unambiguously correct answer. For this reason, there was also a subset of trials that got “held out” because the two options had the same shortest path distance from the target.

Several of the computational models (namely, BFS-backward and the ideal observer) therefore predict 50/50 choices on these trials. BFS-forward does not, as it allows for stochasticity in how agents perform searches from each of the two options. Most notably, the SR often predicts that an agent will prefer one option over the other, which reflects the fact that (e.g.) although Sources A and B have the same shortest path distance to the Target, Source A might have a greater number of short paths to the Target.

Therefore, if subjects’ behaviors are consistent with non-random responding, we’d ideally like to see that the SR is able to make more accurate out-of-sample predictions on these trials.

# Out-of-sample likelihoods on the held-out trials (those where both options
# are equally far from the target). One row per subject/measurement/trial,
# with the negative log-likelihood of the subject's choice under each model.
heldout_likelihoods <- ppc_sr_navigation %>%
  filter(two_correct_options == TRUE) %>%
  select(
    study, sub_id, measurement_id,
    shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id,
    opt1_sr, opt2_sr, sub_choice, softmax_temperature
  ) %>%
  # Which option does the SR favour? NA marks trials where it is indifferent
  mutate(
    sr_prefers = case_when(
      opt1_sr == opt2_sr ~ NA_real_,
      opt1_sr > opt2_sr ~ opt1_id,
      TRUE ~ opt2_id
    )
  ) %>%
  # Calculate the likelihood of the option the subject actually chose under
  # the SR. The softmax must be indexed by the subject's choice: indexing it
  # by the SR-preferred option (as was done previously) yields a value that
  # is always >= 0.5 and never depends on what the subject did.
  rowwise() %>%
  mutate(
    p_sub_choice = softmax(
      option_values = c(opt1_sr, opt2_sr),
      option_chosen = if_else(sub_choice == opt1_id, 1, 2),
      temperature = softmax_temperature,
      use_inverse_temperature = TRUE
    ),
    # Fix a few edge cases
    p_sub_choice = case_when(
      # SR is indifferent, so either choice is equally likely
      is.na(sr_prefers) ~ 0.5,
      # NaN from the softmax is treated as a deterministic preference
      # (presumably numerical overflow -- TODO confirm against softmax())
      is.nan(p_sub_choice) & (sub_choice == sr_prefers) ~ 1,
      # To avoid log(0), use machine epsilon
      is.nan(p_sub_choice) & (sub_choice != sr_prefers) ~ 2.22e-16,
      TRUE ~ p_sub_choice
    )
  ) %>%
  ungroup() %>%
  mutate(neg_ll_sr = neg_loglik_logistic(p_sub_choice)) %>%
  # Tidy up the SR bit
  select(
    study, sub_id, measurement_id,
    shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id,
    sr_prefers, sub_choice,
    neg_ll_sr
  ) %>%
  # Add likelihoods for BFS-backward, which always predicts 50/50 responding
  mutate(neg_ll_bfs_backward = neg_loglik_logistic(0.5)) %>%
  # Add likelihoods for BFS-forward
  left_join(
    ppc_bfs_forward %>%
      select(
        study, sub_id, measurement_id,
        shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id,
        p_bfs_chooses_opt1, p_complete_bfs, p_give_up
      ),
    by = join_by(
      study, sub_id, measurement_id,
      shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id
    )
  ) %>%
  mutate(
    # Likelihood of the subject's choice under BFS-forward: a completed
    # search picks each option at its simulated rate, and an abandoned
    # search is a coin flip. As for the SR, this is indexed by the option
    # the subject chose, not by the option the SR prefers.
    p_sub_choice = case_when(
      # NOTE(review): SR-indifferent trials are scored at chance for
      # BFS-forward too (kept from the original analysis) -- confirm intent
      is.na(sr_prefers) ~ 0.5,
      sub_choice == opt1_id ~ (
        (p_complete_bfs * p_bfs_chooses_opt1) + (p_give_up * 1/2)
      ),
      sub_choice == opt2_id ~ (
        (p_complete_bfs * (1-p_bfs_chooses_opt1)) + (p_give_up * 1/2)
      )
    ),
    neg_ll_bfs_forward = neg_loglik_logistic(p_sub_choice)
  ) %>%
  # Add likelihoods for ideal observer, which always predicts 50/50 responding
  mutate(neg_ll_ideal_obs = neg_loglik_logistic(0.5)) %>%
  # Tidy up
  select(
    study, sub_id, measurement_id,
    shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id,
    sr_prefers, sub_choice,
    neg_ll_sr, neg_ll_bfs_backward, neg_ll_bfs_forward, neg_ll_ideal_obs
  )

Below, we see that the SR, compared to the other models, is doing a better job of explaining subjects’ choices on held-out trials, though we again note that the likelihoods for the BFS-backward and ideal observer models are for completely random responding. The different studies have different baseline log-likelihoods because they contain a different number of trials (i.e., in Studies 2-3, we’re summing over both Day 1 and Day 2 measurements).

# Held-out log-likelihoods per subject and measurement, dodged by model,
# with one panel per study.
plot_heldout_likelihoods <- heldout_likelihoods %>%
  # Total negative log-likelihood per subject/measurement and model
  group_by(study, sub_id, measurement_id) %>%
  summarise(across(starts_with("neg_ll_"), sum), .groups = "drop") %>%
  # Long format; the "neg_ll_" prefix is stripped to leave the model name
  pivot_longer(
    starts_with("neg_ll_"),
    names_to = "model",
    names_prefix = "neg_ll_",
    values_to = "neg_ll"
  ) %>%
  # Flip the sign so that greater = better
  mutate(loglik = -neg_ll) %>%
  ggplot(aes(x = measurement_id, y = loglik, color = model)) +
  theme_custom() +
  facet_wrap(~study, scales = "free_x") +
  # Group median per model
  stat_summary(
    geom = "crossbar", fun = median, position = position_dodge(width = 0.75)
  ) +
  # Individual subjects, jittered within their model's dodge position
  geom_point(
    alpha = 0.2,
    position = position_jitterdodge(
      jitter.width = 0.1, jitter.height = 0, dodge.width = 0.75, seed = 1
    ),
    show.legend = FALSE
  ) +
  scale_color_manual(
    name = NULL,
    labels = c(
      "bfs_backward" = "BFS-backward",
      "bfs_forward" = "BFS-forward",
      "ideal_obs" = "Ideal observer",
      "sr" = "Successor Rep."
    ),
    values = c(
      "bfs_backward" = "#a6dba0",
      "bfs_forward" = "#5aae61",
      "ideal_obs" = "#1b7837",
      "sr" = "#af8dc3"
    )
  ) +
  scale_x_discrete(
    name = NULL,
    labels = c(
      "D1" = "Before\nrest",
      "D2" = "After\novernight\nrest",
      "D1b" = "After\nawake\nrest"
    )
  ) +
  scale_y_continuous(name = "log-likelihood\n(greater = better)") +
  theme(legend.position = "bottom") +
  ggtitle("Held-out trials: log-likelihoods")

plot_heldout_likelihoods

# Write the figure to disk only while knitting the HTML report
if (knitting) {
  ggsave(
    filename = here("outputs", workflow_name, "heldout_likelihoods.pdf"),
    plot = plot_heldout_likelihoods,
    width = 6, height = 4, units = "in", dpi = 300
  )
}

Although the likelihoods give us a nice quantitative metric, we may also be interested in knowing how well the SR predicts subjects’ choices just in terms of accuracy. In the below analysis, we’ll use mixed-effects logistic regression to test whether subjects significantly choose the SR-preferred option. Note that we’re removing all of the trials where the SR is indifferent to the two options, as those trials lead us to overestimate the model’s predicted accuracy (i.e., because the subject is always right on those trials). Note also that we’re iteratively re-parameterizing the model with a different reference category each time, so that we can test whether model accuracy is significantly above chance at each distance. Finally, note that for the purpose of statistical testing, we’re looking at the two main measurements: before rest (day 1), and after overnight rest (day 2).

# Mixed-effects logistic regressions: does the subject pick the SR-preferred
# option? The same model is fit three times, changing only the reference
# level of shortest_path, so each fit's intercept tests a different distance
# against chance.

# No releveling: the factor's existing first level is the reference
stats_heldout_accuracy_dist2 <- heldout_likelihoods %>%
  # Main measurements only, and only trials where the SR has a preference
  filter(measurement_id %in% c("D1", "D2"), !is.na(sr_prefers)) %>%
  mutate(
    # Subject numbers are made unique across studies
    sub_id = str_c(study, ", s", sub_id),
    p_sub_chooses_sr_preference = sub_choice == sr_prefers
  ) %>%
  glmmTMB(
    p_sub_chooses_sr_preference ~ shortest_path +
      (1 | sub_id) + (1 | study),
    family = binomial,
    data = .
  )

# Reference level: distance 3
stats_heldout_accuracy_dist3 <- heldout_likelihoods %>%
  filter(measurement_id %in% c("D1", "D2"), !is.na(sr_prefers)) %>%
  mutate(
    shortest_path = fct_relevel(shortest_path, "3"),
    sub_id = str_c(study, ", s", sub_id),
    p_sub_chooses_sr_preference = sub_choice == sr_prefers
  ) %>%
  glmmTMB(
    p_sub_chooses_sr_preference ~ shortest_path +
      (1 | sub_id) + (1 | study),
    family = binomial,
    data = .
  )

# Reference level: distance 4
stats_heldout_accuracy_dist4 <- heldout_likelihoods %>%
  filter(measurement_id %in% c("D1", "D2"), !is.na(sr_prefers)) %>%
  mutate(
    shortest_path = fct_relevel(shortest_path, "4"),
    sub_id = str_c(study, ", s", sub_id),
    p_sub_chooses_sr_preference = sub_choice == sr_prefers
  ) %>%
  glmmTMB(
    p_sub_chooses_sr_preference ~ shortest_path +
      (1 | sub_id) + (1 | study),
    family = binomial,
    data = .
  )

# Stack the tidied coefficient tables of the three fits, tagging each row
# with the reference category its fit used, and print them as one table.
list(
  "dist2" = stats_heldout_accuracy_dist2,
  "dist3" = stats_heldout_accuracy_dist3,
  "dist4" = stats_heldout_accuracy_dist4
) %>%
  map_dfr(
    .f = function(fit) select(tidy(fit), -component),
    .id = "ref_cat"
  ) %>%
  kable_custom("Held-out trials: SR model accuracy", grouping_var = ref_cat)
Held-out trials: SR model accuracy
effect group term estimate std.error statistic p.value
dist2
fixed NA (Intercept) 0.002 0.058 0.036 0.971
fixed NA shortest_path3 0.307 0.085 3.618 0.000
fixed NA shortest_path4 0.196 0.098 1.993 0.046
ran_pars sub_id sd__(Intercept) 0.289 NA NA NA
ran_pars study sd__(Intercept) 0.000 NA NA NA
dist3
fixed NA (Intercept) 0.309 0.071 4.378 0.000
fixed NA shortest_path2 -0.307 0.085 -3.618 0.000
fixed NA shortest_path4 -0.111 0.106 -1.047 0.295
ran_pars sub_id sd__(Intercept) 0.289 NA NA NA
ran_pars study sd__(Intercept) 0.000 NA NA NA
dist4
fixed NA (Intercept) 0.198 0.086 2.294 0.022
fixed NA shortest_path2 -0.196 0.098 -1.993 0.046
fixed NA shortest_path3 0.111 0.106 1.047 0.295
ran_pars sub_id sd__(Intercept) 0.289 NA NA NA
ran_pars study sd__(Intercept) 0.000 NA NA NA

In the plot below, we can see that there’s quite a bit of variability across subjects (black lines and datapoints), but also that at the group level, the SR achieves above-chance accuracy for the held-out distance-3 and distance-4 trials (red).

# Group-level predictions at each distance. predict_glmmTMB() is defined with
# re.form = NA, so the NA subject/study placeholders are not used for the fit.
predict_heldout_accuracy <- predict_glmmTMB(
  make_predictions_for = tibble(
    shortest_path = factor(2:4), sub_id = NA, study = NA
  ),
  model_object = stats_heldout_accuracy_dist2
)

# Per-subject proportion of SR-consistent choices (black), overlaid with the
# model's group-level estimate +/- one standard error (red).
plot_heldout_accuracy <- heldout_likelihoods %>%
  # Main measurements only, and only trials where the SR has a preference
  filter(measurement_id %in% c("D1", "D2"), !is.na(sr_prefers)) %>%
  group_by(study, sub_id, shortest_path) %>%
  summarise(
    p_sub_chooses_sr_preference = mean(sub_choice == sr_prefers),
    .groups = "drop"
  ) %>%
  ggplot(aes(x = shortest_path, y = p_sub_chooses_sr_preference)) +
  theme_custom() +
  # Chance level
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  geom_line(aes(group = interaction(study, sub_id)), alpha = 0.1) +
  geom_point(alpha = 0.1) +
  # This layer draws from the prediction table, not the per-subject data
  geom_pointrange(
    aes(x = shortest_path, y = fit, ymin = fit - se.fit, ymax = fit + se.fit),
    data = predict_heldout_accuracy, inherit.aes = FALSE,
    # fatten = 1, size = 1,
    linewidth = 1, color = "red"
  ) +
  scale_y_continuous(
    name = "p(Human chooses SR-preferred Source)",
    labels = scales::percent,
    breaks = seq(0, 1, .25),
    expand = expansion(mult = c(0.1, 0.15))
  ) +
  scale_x_discrete(name = "Shortest path distance") +
  ggtitle("Held-out trials: SR model accuracy")

plot_heldout_accuracy

# Write the figure to disk only while knitting the HTML report
if (knitting) {
  ggsave(
    filename = here("outputs", workflow_name, "heldout_accuracy.pdf"),
    plot = plot_heldout_accuracy,
    width = 4, height = 5, units = "in", dpi = 300
  )
}

McFadden’s pseudo R-squared

Finally, it can be useful to get a sense for each model’s “goodness-of-fit” by comparing its log-likelihood to the log-likelihood of a null model (i.e., McFadden’s R-squared, defined as one minus the ratio of the model’s log-likelihood to the null model’s log-likelihood). Here, the null model is an agent that chooses completely at random on every trial.

# McFadden's pseudo R-squared per subject/measurement/model:
# 1 - (model log-likelihood / null log-likelihood).
mcfadden_r2 <- aicc %>%
  filter(model != "sr_delta_rule") %>%
  mutate(model_loglik = -neg_loglik) %>%
  select(study, sub_id, measurement_id, model, model_loglik) %>%
  mutate(
    # Null model is chance-level choice on every trial, so its log-likelihood
    # is log(0.5) times the number of trials with one correct option, which
    # makes this exactly equal to the usual formulation
    null_loglik = nrow(filter(nav_trials, two_correct_options == FALSE)) *
      log(0.5),
    mcfadden_r2 = 1 - (model_loglik / null_loglik)
  )

# Mean R2 per study and measurement, with one column per model
mcfadden_r2 %>%
  group_by(study, measurement_id, model) %>%
  summarise(Mean = mean(mcfadden_r2), .groups = "drop") %>%
  pivot_wider(
    id_cols = c(study, measurement_id),
    names_from = model,
    values_from = Mean
  ) %>%
  kable_custom("Mean McFadden's R2", grouping_var = study)
Mean McFadden’s R2
measurement_id bfs_backward bfs_forward ideal_obs sr_analytic
Study 1
D1 0.185 0.171 0.129 0.200
Study 2
D1 0.252 0.245 0.221 0.275
D2 0.310 0.291 0.293 0.335
Study 3
D1 0.290 0.273 0.266 0.311
D1b 0.281 0.271 0.276 0.306
D2 0.338 0.318 0.335 0.374
# Median R2 per study and measurement, with one column per model
mcfadden_r2 %>%
  group_by(study, measurement_id, model) %>%
  summarise(Median = median(mcfadden_r2), .groups = "drop") %>%
  pivot_wider(
    id_cols = c(study, measurement_id),
    names_from = model,
    values_from = Median
  ) %>%
  kable_custom("Median McFadden's R2", grouping_var = study)
Median McFadden’s R2
measurement_id bfs_backward bfs_forward ideal_obs sr_analytic
Study 1
D1 0.114 0.124 0.063 0.151
Study 2
D1 0.184 0.172 0.138 0.195
D2 0.226 0.221 0.188 0.245
Study 3
D1 0.198 0.171 0.134 0.194
D1b 0.166 0.155 0.161 0.196
D2 0.237 0.218 0.209 0.268

We see in the plot that, generally, all models are doing better than the null model, and also that the SR consistently seems to have the best likelihood ratio.

# McFadden's R2 per subject and measurement, dodged by model, with one panel
# per study.
plot_mcfadden_r2 <- mcfadden_r2 %>%
  ggplot(aes(x = measurement_id, y = mcfadden_r2, color = model)) +
  theme_custom() +
  facet_wrap(~study, scales = "free_x") +
  # Individual subjects, jittered within their model's dodge position
  geom_point(
    alpha = 0.1,
    position = position_jitterdodge(
      jitter.width = 0.1, jitter.height = 0, dodge.width = 0.75, seed = 1
    ),
    show.legend = FALSE
  ) +
  # Group median per model, drawn on top of the points
  stat_summary(
    geom = "crossbar", fun = median, position = position_dodge(width = 0.75)
  ) +
  scale_color_manual(
    name = NULL,
    labels = c(
      "bfs_backward" = "BFS-backward",
      "bfs_forward" = "BFS-forward",
      "ideal_obs" = "Ideal observer",
      "sr_analytic" = "Successor Representation"
    ),
    values = c(
      "bfs_backward" = "#a6dba0",
      "bfs_forward" = "#5aae61",
      "ideal_obs" = "#1b7837",
      "sr_analytic" = "#af8dc3"
    )
  ) +
  scale_x_discrete(
    name = NULL,
    labels = c(
      "D1" = "Before\nrest",
      "D2" = "After\novernight\nrest",
      "D1b" = "After\nawake\nrest"
    )
  ) +
  scale_y_continuous(name = bquote(~"McFadden's"~R^2)) +
  theme(legend.position = "bottom") +
  ggtitle("Model goodness-of-fit")

plot_mcfadden_r2

# Write the figure to disk only while knitting the HTML report
if (knitting) {
  ggsave(
    filename = here("outputs", workflow_name, "mcfadden_r2.pdf"),
    plot = plot_mcfadden_r2,
    width = 6, height = 4, units = "in", dpi = 300
  )
}

Figures for supplement

# Supplementary figure: the three model-comparison panels side by side with a
# shared legend.
plot_model_comparison_for_supp <- wrap_plots(
  list(plot_akaike_group, plot_best_fitting_model_prop, plot_pxp),
  guides = "collect",
  nrow = 1
) +
  plot_annotation(
    title = "Model comparison",
    tag_levels = "A",
    tag_suffix = ".",
    theme = theme(plot.title = element_text(hjust = 0.5))
  ) &
  # `&` applies the legend placement to every panel
  theme(legend.position = "bottom")

plot_model_comparison_for_supp

# Write the figure to disk only while knitting the HTML report
if (knitting) {
  ggsave(
    filename = here("figures", "supp_model_comparison.pdf"),
    plot = plot_model_comparison_for_supp,
    width = 16, height = 4, units = "in", dpi = 300
  )
}
# Supplementary figure: the day-2 and day-1b PPC plots stacked vertically,
# retitled, and forced onto a shared y-axis.
plot_ppc_for_supp <- wrap_plots(
  list(
    plot_ppc_day2 + ggtitle("After overnight rest (Studies 2-3)"),
    plot_ppc_day1b + ggtitle("After awake rest (Study 3)")
  ),
  ncol = 1
) +
  plot_annotation(
    title = "Posterior predictive check",
    tag_levels = "A",
    tag_suffix = ".",
    theme = theme(plot.title = element_text(hjust = 0.5))
  ) &
  # `&` replaces the y scale in both panels, so ggplot2 reports that a
  # y scale was already present
  scale_y_continuous(
    name = "Accuracy",
    labels = scales::percent,
    limits = c(.25, 1),
    breaks = seq(.25, 1, .25)
  )
## Scale for y is already present.
## Adding another scale for y, which will replace the existing scale.
## Scale for y is already present.
## Adding another scale for y, which will replace the existing scale.
plot_ppc_for_supp

# Write the figure to disk only while knitting the HTML report
if (knitting) {
  ggsave(
    filename = here("figures", "supp_ppc.pdf"),
    plot = plot_ppc_for_supp,
    width = 8, height = 6, units = "in", dpi = 300
  )
}

There are some plots that we want to use as-is, so we’ll save a redundant copy in the figures folder.

# Save copies of plots that go into the supplement unchanged. This only runs
# while knitting, and every file is written to the top-level "figures" folder.
if (knitting) {
  # Parameter estimates
  ggsave(
    filename = here("figures", "supp_param_estimates.pdf"),
    plot = plot_params_all,
    width = 12, height = 6,
    units = "in", dpi = 300
  )
  
  # Akaike weights for individual subjects
  ggsave(
    filename = here("figures", "supp_akaike_weights_per_subject.pdf"),
    plot = plot_akaike_individual,
    width = 8, height = 10,
    units = "in", dpi = 300
  )
  
  # Held-out trials: log-likelihoods
  ggsave(
    filename = here("figures", "supp_heldout_likelihoods.pdf"),
    plot = plot_heldout_likelihoods,
    width = 6, height = 4,
    units = "in", dpi = 300
  )
  
  # Held-out trials: SR model accuracy
  ggsave(
    filename = here("figures", "supp_heldout_accuracy.pdf"),
    plot = plot_heldout_accuracy,
    width = 4, height = 5,
    units = "in", dpi = 300
  )
  
  # McFadden's pseudo R-squared
  ggsave(
    filename = here("figures", "supp_mcfadden_r2.pdf"),
    plot = plot_mcfadden_r2,
    width = 6, height = 4,
    units = "in", dpi = 300
  )
}
## `geom_line()`: Each group consists of only one observation.
## ℹ Do you need to adjust the group aesthetic?
## `geom_line()`: Each group consists of only one observation.
## ℹ Do you need to adjust the group aesthetic?
## `geom_line()`: Each group consists of only one observation.
## ℹ Do you need to adjust the group aesthetic?
## `geom_line()`: Each group consists of only one observation.
## ℹ Do you need to adjust the group aesthetic?
## `geom_line()`: Each group consists of only one observation.
## ℹ Do you need to adjust the group aesthetic?
---
title: "Model comparison"
output:
  html_document:
    code_download: true
    code_folding: hide
    toc: true
    toc_float:
      collapsed: true
---

# Setup

```{r libraries}
workflow_name <- "netnav_06_model_comparison"

library(tidyverse)
library(here)
library(patchwork)

library(glmmTMB)
library(broom.mixed)

source(here("code", "utils", "modeling_utils.R"))
source(here("code", "utils", "representation_utils.R"))
source(here("code", "utils", "bayesian_model_selection.R"))

source(here("code", "utils", "ggplot_themes.R"))
source(here("code", "utils", "kable_utils.R"))
source(here("code", "utils", "unicode_greek.R"))

knitting <- knitr::is_html_output()

# Create the directory `this_path` (and any missing parents) if it is absent.
# Returns NULL invisibly when the directory already exists; otherwise returns
# whatever dir.create() returns.
create_path <- function(this_path) {
  already_there <- dir.exists(this_path)
  if (already_there) {
    return(invisible(NULL))
  }
  dir.create(this_path, recursive = TRUE)
}

# Append model predictions to a data frame of predictor values.
#
# make_predictions_for: data frame of rows to predict for
# model_object: fitted model handed to predict()
#
# Predictions are requested on the response scale with random effects
# excluded (re.form = NA). Because se.fit = TRUE, both the fit and its
# standard error are bound onto the input as extra columns.
predict_glmmTMB <- function(make_predictions_for, model_object) {
  predictions <- predict(
    object = model_object,
    newdata = make_predictions_for,
    re.form = NA, allow.new.levels = TRUE, se.fit = TRUE, type = "response"
  )
  bind_cols(make_predictions_for, predictions)
}

# Output folders are only needed (and only created) while knitting
if (knitting) {
  create_path(here("outputs", workflow_name))
  create_path(here("figures"))
}
```

```{r load-behav-data}
# Study 1 navigation trials, reduced to the columns shared by all studies
nav_study1 <- read_csv(
  here("data", "clean_data", "study1_message_passing.csv"),
  show_col_types = FALSE
) %>%
  # Trials with two correct options are deliberately kept (no filter on
  # two_correct_options); they are analysed later as held-out trials
  filter(shortest_path_given_opts == shortest_path_given_start_end) %>%
  mutate(
    study = "Study 1",
    # Measurement numbers become labels such as "D1"
    measurement_id = str_c("D", measurement_id),
    shortest_path = factor(shortest_path_given_opts)
  ) %>%
  select(
    study, sub_id, measurement_id, shortest_path,
    startpoint_id, endpoint_id, opt1_id, opt2_id,
    correct_choice, sub_choice, correct, rt,
    two_correct_options,
    opt1_distance = dist_opt1,
    opt2_distance = dist_opt2
  )

# Study 2 navigation trials, reduced to the columns shared by all studies
nav_study2 <- read_csv(
  here("data", "clean_data", "study2_message_passing.csv"),
  show_col_types = FALSE
) %>%
  # Trials with two correct options are deliberately kept (no filter on
  # two_correct_options); they are analysed later as held-out trials
  filter(shortest_path_given_opts == shortest_path_given_start_end) %>%
  mutate(
    study = "Study 2",
    # Learned-network trials are labelled by measurement number;
    # reevaluated-network trials all get the label "D2b"
    measurement_id = case_when(
      network == "learned" ~ str_c("D", measurement_id),
      network == "reevaluated" ~ "D2b"
    ),
    shortest_path = factor(shortest_path_given_opts)
  ) %>%
  select(
    study, sub_id, measurement_id, shortest_path,
    startpoint_id, endpoint_id, opt1_id, opt2_id,
    correct_choice, sub_choice, correct, rt,
    two_correct_options,
    opt1_distance = dist_opt1,
    opt2_distance = dist_opt2
  )

# Study 3 navigation trials, reduced to the columns shared by all studies
nav_study3 <- read_csv(
  here("data", "clean_data", "study3_message_passing.csv"),
  show_col_types = FALSE
) %>%
  # Trials with two correct options are deliberately kept (no filter on
  # two_correct_options); they are analysed later as held-out trials
  filter(shortest_path_given_opts == shortest_path_given_start_end) %>%
  mutate(
    study = "Study 3",
    # The reevaluated network is checked first, so it takes precedence over
    # the numeric measurement labels
    measurement_id = case_when(
      network == "reevaluated" ~ "D2b",
      measurement_id == 1 ~ "D1",
      measurement_id == 2 ~ "D1b",
      measurement_id == 3 ~ "D2"
    ),
    shortest_path = factor(shortest_path_given_opts)
  ) %>%
  select(
    study, sub_id, measurement_id, shortest_path,
    startpoint_id, endpoint_id, opt1_id, opt2_id,
    correct_choice, sub_choice, correct, rt,
    two_correct_options,
    opt1_distance = dist_opt1,
    opt2_distance = dist_opt2
  )
```

```{r load-data-for-ppc}
# BFS-backward simulations, averaged down to one row per unique trial
bfs_backward_sims <- read_csv(
  here("data", "bfs_sims", "bfs_sims_learned_backward.csv"),
  show_col_types = FALSE
) %>%
  # Trials with two correct options are deliberately kept (no filter on
  # two_correct_options)
  filter(shortest_path_given_opts == shortest_path_given_start_end) %>%
  mutate(shortest_path = factor(shortest_path_given_opts)) %>%
  select(-starts_with("shortest_path_given")) %>%
  group_by(
    shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id,
    correct_choice, two_correct_options
  ) %>%
  summarise(
    # Share of simulated runs that ended on the correct option
    p_bfs_correct = mean(bfs_choice == correct_choice),
    # Share of simulated runs that ended on option 1
    p_bfs_chooses_opt1 = mean(bfs_choice == opt1_id),
    # Average number of visits per run
    bfs_visits = mean(bfs_n_visits_total),
    .groups = "drop"
  )

# BFS-forward simulations, averaged down to one row per unique trial
bfs_forward_sims <- read_csv(
  here("data", "bfs_sims", "bfs_sims_learned_forward.csv"),
  show_col_types = FALSE
) %>%
  # Trials with two correct options are deliberately kept (no filter on
  # two_correct_options)
  filter(shortest_path_given_opts == shortest_path_given_start_end) %>%
  mutate(shortest_path = factor(shortest_path_given_opts)) %>%
  select(-starts_with("shortest_path_given")) %>%
  group_by(
    shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id,
    correct_choice, two_correct_options
  ) %>%
  summarise(
    # Share of simulated runs that ended on the correct option
    p_bfs_correct = mean(bfs_choice == correct_choice),
    # Share of simulated runs that ended on option 1
    p_bfs_chooses_opt1 = mean(bfs_choice == opt1_id),
    # Average number of visits per run
    bfs_visits = mean(bfs_n_visits_total),
    .groups = "drop"
  )

# Canonical list of navigation trials, taken from Study 1's subject 1. Other
# chunks cross this list with every subject/measurement when simulating model
# predictions.
nav_trials <- here("data", "clean_data", "study1_message_passing.csv") %>%
  read_csv(show_col_types = FALSE) %>%
  filter(
    # two_correct_options == FALSE,
    shortest_path_given_opts == shortest_path_given_start_end
  ) %>%
  # The factor is built before restricting to one subject, so its levels come
  # from every subject's trials
  mutate(shortest_path = factor(shortest_path_given_opts)) %>%
  # Keep a single subject's rows as the trial list
  filter(sub_id == 1) %>%
  select(
    shortest_path,
    startpoint_id, endpoint_id,
    opt1_id, opt2_id,
    correct_choice,
    opt1_distance = dist_opt1,
    opt2_distance = dist_opt2,
    two_correct_options
  ) %>%
  arrange(shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id) %>%
  # Replace undefined distances (corresponding to impossible options)
  # so that the softmax gets non-NA inputs; we assume that impossible
  # options are just as bad as the longest distance found in this set
  # of trials, i.e., a distance of 8
  mutate(across(c(opt1_distance, opt2_distance), ~replace_na(.x, 8)))

# Adjacency list for the learned network (columns used below: from, to, edge)
adjlist <- read_csv(
  here("data", "clean_data", "adjlist_learned.csv"),
  show_col_types = FALSE
)

# Transition matrix: rows are `from` nodes, columns are `to` nodes
transmat <- adjlist %>%
  # Rescale each `from` node's edges so that they sum to 1
  mutate(edge = edge / sum(edge), .by = from) %>%
  pivot_wider(names_from = to, values_from = edge) %>%
  column_to_rownames("from") %>%
  as.matrix()
```

```{r load-params}
# Set to TRUE to rebuild the cleaned parameter file from the individual fit
# files; otherwise the previously written cleaned file is simply re-read
load_params_from_scratch <- FALSE

if (load_params_from_scratch) {
  # Regex for the model ID as it appears inside each filename
  # (flanked by underscores)
  model_id_regex <- str_c(
    "_",
    "(bfs_(backward|forward)|",
    "ideal_obs|",
    "sr_(analytic|delta_rule))_"
  )

  params <- fs::dir_ls(
    here("data", "param_fits"),
    recurse = 1,
    regexp = str_c(model_id_regex, "(.)+\\.csv")
  ) %>%
    # Stack every fit file, remembering which file each row came from
    map_dfr(
      .f = function(path) read_csv(path, show_col_types = FALSE),
      .id = "filename"
    ) %>%
    mutate(
      # Model ID: matched pattern minus its leading/trailing underscore
      model = filename %>%
        str_extract(model_id_regex) %>%
        str_sub(2, -2),
      # Study ID, e.g. "study1" -> "Study 1"
      study = filename %>%
        str_extract("study[[:digit:]]") %>%
        str_replace("study", "Study "),
      # Subject ID as an integer
      sub_id = filename %>%
        str_extract("sub_[[:digit:]]+") %>%
        str_remove("sub_") %>%
        as.integer(),
      # Measurement ID, e.g. "D1", "D1b", "D2"
      measurement_id = filename %>%
        str_extract("_D[[:digit:]]b?") %>%
        str_remove("_"),
      # Prefer the human-readable parameter value whenever one is available
      param_value = if_else(
        !is.na(param_value_human_readable),
        param_value_human_readable,
        param_value
      )
    ) %>%
    # Among converged runs, keep the one(s) with the smallest objective value
    filter(convergence == "converged") %>%
    group_by(model, study, sub_id, measurement_id) %>%
    slice_min(optim_value, n = 1) %>%
    ungroup() %>%
    # Several runs can tie for "best"; if so, keep whichever was run first
    group_by(model, study, sub_id, measurement_id) %>%
    slice_min(optimizer_run, n = 1) %>%
    ungroup() %>%
    # Keep only the columns needed downstream
    select(
      model, study, sub_id, measurement_id,
      param_name, param_value,
      neg_loglik = optim_value
    ) %>%
    arrange(model, study, sub_id, measurement_id, param_name)

  # Make sure the output folder exists before writing
  create_path(here("data", "param_fits", "clean_params"))

  write_csv(
    params,
    file = here("data", "param_fits", "clean_params", "clean_param_fits.csv")
  )
}

# Always read the cleaned file back in, whether or not it was just rebuilt
params <- read_csv(
  here("data", "param_fits", "clean_params", "clean_param_fits.csv"),
  show_col_types = FALSE
)
```


# AICc

As our metric of log-evidence, we'll use AICc, i.e., AIC corrected for a relatively small N.

```{r calc-aicc}
# AICc for every model fit (one row per study / subject / measurement / model)
aicc <- params %>%
  # `params` has one row per parameter; collapse to one row per model fit
  select(study, sub_id, measurement_id, model, neg_loglik) %>%
  distinct() %>%
  arrange(study, sub_id, measurement_id) %>%
  mutate(
    # Both SR implementations ("sr_analytic", "sr_delta_rule") estimate two
    # parameters (sr_gamma and softmax_temperature); every other model
    # estimates one. Match on the "sr" prefix: an exact comparison against
    # "sr" never matches the actual model IDs, which would under-penalize
    # the SR models.
    n_params = if_else(str_starts(model, "sr"), 2, 1),
    n_datapoints = 115,
    # AIC = -2 * loglik + 2k, written in terms of the negative log-likelihood
    aic = (2 * neg_loglik) + (2 * n_params),
    # Small-sample correction term
    aicc = aic + (
      (2 * n_params * (n_params + 1)) / (n_datapoints - n_params - 1)
    )
  )
```


# SR analytic vs delta-rule

Before we go any further, let's just get one targeted comparison out of the way. In earlier scripts, we saw that SR matrices can be constructed using a closed-form analytic solution, or a delta-rule updating mechanism. These different implementations can, in principle, end up making very different predictions. Here, we'll see whether there's evidence that one implementation fits better than the other.

Below, we're directly comparing the AICc of the analytic vs delta-rule implementations. Each datapoint is one subject, and the lines connect a subject's AICc from one implementation to the other. The red bars reflect the means. The lines are remarkably flat, indicating that there is functionally no real difference in the model goodness-of-fit.

```{r sr-comparison-aicc}
# Within-subject comparison of AICc between the two SR implementations
plot_sr_comparison_aicc <- aicc %>%
  filter(str_detect(model, "sr_")) %>%
  # One facet per study/measurement combination
  mutate(facet_label = str_c(study, ", ", measurement_id)) %>%
  ggplot(aes(x = model, y = aicc)) +
  theme_custom() +
  facet_wrap(vars(facet_label), scales = "free_x") +
  geom_point(alpha = 0.5) +
  # Connect each subject's two AICc values
  geom_line(aes(group = sub_id), alpha = 0.2) +
  # Red bar marks the mean
  stat_summary(geom = "crossbar", fun = mean, color = "red") +
  scale_x_discrete(
    name = "SR implementation",
    labels = c("sr_analytic" = "Analytic", "sr_delta_rule" = "Delta-rule")
  ) +
  labs(y = "AICc", title = "AICc comparison of SR implementations")

plot_sr_comparison_aicc

# Only write the figure to disk when knitting
if (knitting) {
  ggsave(
    here("outputs", workflow_name, "sr_aicc_analytic_vs_delta_rule.pdf"),
    plot = plot_sr_comparison_aicc,
    width = 6,
    height = 4,
    units = "in",
    dpi = 300
  )
}
```

We're not interested in doing any sort of hypothesis testing here, but for the purpose of doing model selection, we do want to know whether choosing one implementation over the other might result in forming different conclusions about the parameter fits.

Below, the bars reflect medians. We can see that, by-and-large, the two implementations result in very similar estimates.

```{r sr-comparison-gamma}
# Compare fitted SR gamma between the analytic and delta-rule implementations
plot_sr_comparison_gamma <- params %>%
  filter(str_detect(model, "sr_"), param_name == "sr_gamma") %>%
  select(model, study, sub_id, measurement_id, sr_gamma = param_value) %>%
  ggplot(aes(x = measurement_id, y = sr_gamma, color = model)) +
  theme_custom() +
  facet_wrap(vars(study), scales = "free_x") +
  geom_point(
    alpha = 0.5,
    position = position_dodge(width = 0.25),
    show.legend = FALSE
  ) +
  # One line per subject within each implementation
  geom_line(aes(group = interaction(model, sub_id)), alpha = 0.2) +
  # Bars mark the median
  stat_summary(geom = "crossbar", fun = median) +
  scale_color_manual(
    name = "SR implementation",
    labels = c("sr_analytic" = "Analytic", "sr_delta_rule" = "Delta-rule"),
    values = c("sr_analytic" = "#ca0020", "sr_delta_rule" = "#0571b0")
  ) +
  labs(
    x = "Measurement",
    y = "SR gamma",
    title = "Estimates of SR gamma: Analytic vs delta-rule"
  ) +
  theme(legend.position = "bottom")

plot_sr_comparison_gamma

# Only write the figure to disk when knitting
if (knitting) {
  ggsave(
    here("outputs", workflow_name, "sr_gamma_analytic_vs_delta_rule.pdf"),
    plot = plot_sr_comparison_gamma,
    width = 6,
    height = 4,
    units = "in",
    dpi = 300
  )
}
```

In the modeling, the goal was to test changes in the gamma parameter, assuming an asymptotic representation. These results suggest that there is nothing lost by using the analytic closed-form implementation, so we'll stick with that from here onwards.


# Description of parameter fits

```{r descriptive-param-fits}
# Table of the mean and median fitted value of every parameter, computed
# separately for each model / study / measurement
params %>%
  group_by(model, study, measurement_id, param_name) %>%
  summarise(
    param_mean = mean(param_value),
    param_median = median(param_value),
    .groups = "drop"
  ) %>%
  # NOTE(review): other calls in this file pass the caption positionally;
  # confirm that `captions` is the argument name kable_custom() expects
  kable_custom(
    captions = "Descriptive stats: parameter fits",
    grouping_var = model
  )
```

```{r plot-param-fits}
# Build one "estimated parameter" panel for a single model/parameter pair.
#
# Arguments:
#   model_id     Value of `model` in `params` to plot (e.g., "bfs_backward")
#   param_col    Name of the parameter (a value of `param_name`) to put on
#                the y-axis
#   y_label      Y-axis label
#   panel_title  Plot title
#
# Returns a ggplot object: red/blue crossbars mark the mean/median, and each
# subject's estimates are drawn as points connected by a line.
make_param_estimate_plot <- function(model_id, param_col, y_label,
                                     panel_title) {
  # Same jitter settings (including the seed) for points and lines, exactly
  # as in the original per-model pipelines
  jitter <- position_jitter(width = 0.1, height = 0, seed = 1)

  params %>%
    filter(model == .env$model_id) %>%
    # One column per parameter name
    pivot_wider(names_from = param_name, values_from = param_value) %>%
    ggplot(aes(x = measurement_id, y = .data[[param_col]])) +
    theme_custom() +
    facet_wrap(~study, scales = "free_x") +
    stat_summary(geom = "crossbar", fun = mean, color = "red") +
    stat_summary(geom = "crossbar", fun = median, color = "blue") +
    geom_point(alpha = 0.1, position = jitter) +
    geom_line(aes(group = sub_id), alpha = 0.25, position = jitter) +
    xlab("Measurement") +
    ylab(y_label) +
    ggtitle(panel_title)
}

plot_params_bfs_backward <- make_param_estimate_plot(
  model_id = "bfs_backward",
  param_col = "search_threshold",
  y_label = "Search threshold",
  panel_title = "BFS-backward"
)

plot_params_bfs_forward <- make_param_estimate_plot(
  model_id = "bfs_forward",
  param_col = "search_threshold",
  y_label = "Search threshold",
  panel_title = "BFS-forward"
)

plot_params_ideal_obs <- make_param_estimate_plot(
  model_id = "ideal_obs",
  param_col = "softmax_temperature",
  y_label = "Inverse temperature",
  panel_title = "Ideal observer"
)

# Zoom the y-axis to 0-3000 for the SR inverse temperature panel
plot_params_sr_temp <- make_param_estimate_plot(
  model_id = "sr_analytic",
  param_col = "softmax_temperature",
  y_label = "Inverse temperature",
  panel_title = "Successor Representation"
) +
  coord_cartesian(ylim = c(0, 3000))

plot_params_sr_gamma <- make_param_estimate_plot(
  model_id = "sr_analytic",
  param_col = "sr_gamma",
  y_label = "Gamma",
  panel_title = "Successor Representation"
)

# Combine: three single-parameter models on top, the two SR parameters below
plot_params_all <- (
  (plot_params_bfs_backward | plot_params_bfs_forward | plot_params_ideal_obs) /
    (plot_params_sr_temp | plot_params_sr_gamma)
) +
  plot_annotation(
    title = "Estimated parameters",
    tag_levels = "A", tag_suffix = ".",
    theme = theme(plot.title = element_text(hjust = 0.5))
  )

plot_params_all

# Only write the figure to disk when knitting
if (knitting) {
  ggsave(
    filename = here("outputs", workflow_name, "param_estimates.pdf"),
    plot = plot_params_all,
    width = 12, height = 6,
    units = "in", dpi = 300
  )
}
```


# Akaike weights

We'll later use protected exceedance probabilities (PXP) to do formal inference to test whether a particular model provides a significantly better group-level fit than other models. But first, we do want to acknowledge that there are likely to be some individual differences in how well a particular model fits each subject. To get a sense for this, we'll use Akaike weights, which provide the probability that a particular model is the "best" given the data and the set of candidate models.

```{r calc-akaike-weights}
# Akaike weights, computed within each subject's set of candidate models
# (the delta-rule SR implementation is excluded from the candidate set)
akaike_weights <- aicc %>%
  filter(model != "sr_delta_rule") %>%
  mutate(
    # Likelihood of each model relative to this subject's lowest-AICc model
    relative_likelihood = exp(-0.5 * (aicc - min(aicc))),
    # Normalize so that the weights sum to 1 within subject
    akaike_weight = relative_likelihood / sum(relative_likelihood),
    # How many times more weight the top model gets than this model
    evidence_ratio = max(akaike_weight) / akaike_weight,
    .by = c(study, measurement_id, sub_id)
  ) %>%
  arrange(study, measurement_id, sub_id, evidence_ratio)
```

We can first average over all subjects' Akaike weights to get a sense for what the "best-fitting" model is across subjects. This suggests that the SR consistently comes out on top, followed pretty consistently by BFS-backward.

```{r plot-group-akaike-weights}
# Stacked bars of the subject-averaged Akaike weight of each model
plot_akaike_group <- akaike_weights %>%
  group_by(study, measurement_id, model) %>%
  summarise(akaike_weight = mean(akaike_weight), .groups = "drop") %>%
  # Label printed inside each bar segment
  mutate(text = round(akaike_weight, 2)) %>%
  ggplot(aes(x = measurement_id, y = akaike_weight, fill = model)) +
  theme_custom() +
  facet_wrap(vars(study), scales = "free_x") +
  geom_col() +
  geom_text(aes(label = text), position = position_stack(vjust = 0.5)) +
  scale_x_discrete(
    name = NULL,
    labels = c(
      "D1" = "Before\nrest",
      "D2" = "After\novernight\nrest",
      "D1b" = "After\nawake\nrest"
    )
  ) +
  scale_y_continuous(
    name = NULL,
    expand = expansion(mult = c(0.01, 0.01))
  ) +
  scale_fill_manual(
    name = NULL,
    labels = c(
      "bfs_backward" = "BFS-backward",
      "bfs_forward" = "BFS-forward",
      "ideal_obs" = "Ideal observer",
      "sr_analytic" = "Successor Representation"
    ),
    values = c(
      "bfs_backward" = "#a6dba0",
      "bfs_forward" = "#5aae61",
      "ideal_obs" = "#1b7837",
      "sr_analytic" = "#af8dc3"
    )
  ) +
  # The y-axis text/ticks are hidden; the in-bar labels carry the values
  theme(
    legend.position = "bottom",
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank()
  ) +
  labs(title = "Akaike weights")

plot_akaike_group

# Only write the figure to disk when knitting
if (knitting) {
  ggsave(
    here("outputs", workflow_name, "akaike_weights_group.pdf"),
    plot = plot_akaike_group,
    width = 6,
    height = 4,
    units = "in",
    dpi = 300
  )
}
```

We can break this out and plot each individual subject's Akaike weights.

```{r plot-individual-akaike-weights}
# Stacked bars of each individual subject's Akaike weights, with one facet
# row per study/measurement combination
plot_akaike_individual <- akaike_weights %>%
  mutate(
    sub_id = factor(sub_id),
    # Spell out the measurement codes (any other code becomes NA)
    measurement_id = case_match(
      measurement_id,
      "D1" ~ "before rest",
      "D1b" ~ "after awake rest",
      "D2" ~ "after overnight rest"
    ),
    # Facet label, ordered chronologically within each study
    study = str_c(study, ", ", measurement_id),
    study = fct_relevel(
      study,
      "Study 1, before rest",
      "Study 2, before rest",
      "Study 2, after overnight rest",
      "Study 3, before rest",
      "Study 3, after awake rest",
      "Study 3, after overnight rest"
    )
  ) %>%
  ggplot(aes(x = sub_id, y = akaike_weight, fill = model)) +
  theme_custom() +
  facet_wrap(vars(study), scales = "free_x", ncol = 1) +
  geom_col() +
  scale_x_discrete(name = "Subject ID") +
  scale_y_continuous(
    name = NULL,
    expand = expansion(mult = c(0.01, 0.01))
  ) +
  scale_fill_manual(
    name = NULL,
    labels = c(
      "bfs_backward" = "BFS-backward",
      "bfs_forward" = "BFS-forward",
      "ideal_obs" = "Ideal observer",
      "sr_analytic" = "Successor Representation"
    ),
    values = c(
      "bfs_backward" = "#a6dba0",
      "bfs_forward" = "#5aae61",
      "ideal_obs" = "#1b7837",
      "sr_analytic" = "#af8dc3"
    )
  ) +
  theme(
    legend.position = "bottom",
    # Rotate subject IDs so that they fit along the x-axis
    axis.text.x = element_text(angle = 90, vjust = 0.5, hjust = 1),
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank()
  ) +
  labs(title = "Akaike weights")

plot_akaike_individual

# Only write the figure to disk when knitting
if (knitting) {
  ggsave(
    here("outputs", workflow_name, "akaike_weights_per_subject.pdf"),
    plot = plot_akaike_individual,
    width = 8,
    height = 10,
    units = "in",
    dpi = 300
  )
}
```

Akaike weights provide a nice goodness-of-fit metric that respects the probabilistic aspect of model comparison, and can do so at both the group and individual level. However, model selection ultimately requires us to make a discrete choice. If we made per-subject decisions simply by choosing the single best-fitting model, what proportion of subjects would be best fit by each model? We see that the pattern of results basically mirrors what we'd seen in the Akaike weights, such that the SR is the best-fitting model for the majority of subjects, followed by BFS-backward.

```{r plot-group-prop-best-fits}
# For every subject/measurement, the model with the largest Akaike weight
# (ties, if any, are all kept)
best_fitting_model_per_sub <- akaike_weights %>%
  group_by(study, measurement_id, sub_id) %>%
  slice_max(akaike_weight, n = 1) %>%
  ungroup() %>%
  select(study, measurement_id, sub_id, best_fitting_model = model)

# Stacked bars: share of subjects for whom each model fits best
plot_best_fitting_model_prop <- best_fitting_model_per_sub %>%
  count(study, measurement_id, best_fitting_model) %>%
  group_by(study, measurement_id) %>%
  mutate(
    # Proportion within each study/measurement
    p = n / sum(n),
    # Percentage label printed inside each bar segment
    text = str_c(round(p, 2) * 100, "%")
  ) %>%
  ggplot(aes(x = measurement_id, y = p, fill = best_fitting_model)) +
  theme_custom() +
  facet_wrap(vars(study), scales = "free_x") +
  geom_col() +
  geom_text(aes(label = text), position = position_stack(vjust = 0.5)) +
  scale_x_discrete(
    name = NULL,
    labels = c(
      "D1" = "Before\nrest",
      "D2" = "After\novernight\nrest",
      "D1b" = "After\nawake\nrest"
    )
  ) +
  scale_y_continuous(
    name = NULL,
    expand = expansion(mult = c(0.01, 0.01))
  ) +
  scale_fill_manual(
    name = NULL,
    labels = c(
      "bfs_backward" = "BFS-backward",
      "bfs_forward" = "BFS-forward",
      "ideal_obs" = "Ideal observer",
      "sr_analytic" = "Successor Representation"
    ),
    values = c(
      "bfs_backward" = "#a6dba0",
      "bfs_forward" = "#5aae61",
      "ideal_obs" = "#1b7837",
      "sr_analytic" = "#af8dc3"
    )
  ) +
  theme(
    legend.position = "bottom",
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank()
  ) +
  labs(title = "Proportion of subjects best fit by each model")

plot_best_fitting_model_prop

# Only write the figure to disk when knitting
if (knitting) {
  ggsave(
    here("outputs", workflow_name, "prop_subjects_best_fit.pdf"),
    plot = plot_best_fitting_model_prop,
    width = 6,
    height = 4,
    units = "in",
    dpi = 300
  )
}
```


# Protected exceedance probabilities

Protected exceedance probabilities provide a formal test of a model’s group-level fit compared to other candidate models. To run this analysis, we'll adapt software originally written by Matteo Lisi (https://github.com/mattelisi/bmsR). We'll use AICc as our metric of log-evidence.

```{r calc-pxp}
# Run Bayesian model selection separately for each study/measurement
pxp_results <- aicc %>%
  filter(model != "sr_delta_rule") %>%
  # PXP treats larger evidence values as better, whereas AICc (built on the
  # negative log-likelihood) treats smaller values as better, so flip the sign
  mutate(aicc = -aicc) %>%
  select(study, measurement_id, sub_id, model, aicc) %>%
  # Reshape to one row per subject and one column per model
  pivot_wider(names_from = model, values_from = aicc) %>%
  select(-sub_id) %>%
  # One nested subjects-by-models table per study/measurement
  group_by(study, measurement_id) %>%
  nest() %>%
  mutate(test = map(data, bayesian_model_selection)) %>%
  ungroup() %>%
  # Drop the nested input and expand the results
  select(-data) %>%
  unnest(test)
```

In the results, it's clear that the SR comes out on top, and by a large margin.

```{r pxp-results}
# Table of PXP results, grouped by study/measurement and sorted so that the
# model with the highest PXP comes first within each group
pxp_results %>%
  mutate(
    # Spell out the measurement codes (any other code becomes NA)
    measurement_id = case_match(
      measurement_id,
      "D1" ~ "before rest",
      "D1b" ~ "after awake rest",
      "D2" ~ "after overnight rest"
    ),
    # Combined label, ordered chronologically within each study
    study = str_c(study, ", ", measurement_id),
    study = fct_relevel(
      study,
      "Study 1, before rest",
      "Study 2, before rest",
      "Study 2, after overnight rest",
      "Study 3, before rest",
      "Study 3, after awake rest",
      "Study 3, after overnight rest"
    )
  ) %>%
  select(!measurement_id) %>%
  arrange(study, desc(pxp)) %>%
  kable_custom("PXP results", grouping_var = study)
```

```{r plot-pxp}
# Stacked bars of each model's protected exceedance probability
plot_pxp <- pxp_results %>%
  mutate(
    # Only the SR segment gets a percentage label; other segments are blank
    text = if_else(
      model == "sr_analytic",
      str_c(round(pxp, 2) * 100, "%"),
      ""
    )
  ) %>%
  ggplot(aes(x = measurement_id, y = pxp, fill = model)) +
  theme_custom() +
  facet_wrap(vars(study), scales = "free_x") +
  geom_col() +
  geom_text(aes(label = text), position = position_stack(vjust = 0.5)) +
  scale_x_discrete(
    name = NULL,
    labels = c(
      "D1" = "Before\nrest",
      "D2" = "After\novernight\nrest",
      "D1b" = "After\nawake\nrest"
    )
  ) +
  scale_y_continuous(
    name = NULL,
    expand = expansion(mult = c(0.01, 0.01))
  ) +
  scale_fill_manual(
    name = NULL,
    labels = c(
      "bfs_backward" = "BFS-backward",
      "bfs_forward" = "BFS-forward",
      "ideal_obs" = "Ideal observer",
      "sr_analytic" = "Successor Representation"
    ),
    values = c(
      "bfs_backward" = "#a6dba0",
      "bfs_forward" = "#5aae61",
      "ideal_obs" = "#1b7837",
      "sr_analytic" = "#af8dc3"
    )
  ) +
  theme(
    legend.position = "bottom",
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank()
  ) +
  labs(title = "Protected exceedance probabilities")

plot_pxp

# Only write the figure to disk when knitting
if (knitting) {
  ggsave(
    here("outputs", workflow_name, "pxp.pdf"),
    plot = plot_pxp,
    width = 6,
    height = 4,
    units = "in",
    dpi = 300
  )
}
```


# Posterior predictive check

## Simulate predicted behaviors

It's nice to see consistency in the *quantitative* model comparison and selection, but we'd also like to see how well human behaviors are *qualitatively* described by our models. To do this, we'll simulate subjects' predicted behaviors given their model parameters.

```{r ppc-bfs-backward}
# Posterior predictive check for BFS-backward: predicted accuracy on every
# simulated trial for every subject/measurement, alongside observed behavior
ppc_bfs_backward <- expand_grid(
  # Every subject/measurement crossed with every BFS-simulated trial
  distinct(params, study, sub_id, measurement_id),
  bfs_backward_sims
) %>%
  # Attach each subject's fitted search threshold
  left_join(
    params %>%
      filter(model == "bfs_backward") %>%
      pivot_wider(names_from = param_name, values_from = param_value) %>%
      select(study, sub_id, measurement_id, search_threshold),
    by = join_by(study, sub_id, measurement_id)
  ) %>%
  mutate(
    # Probability of *completing* the BFS search all the way through: a
    # row-by-row softmax over the subject's threshold and the number of
    # visits that the search requires
    p_complete_bfs = map2_dbl(
      search_threshold, bfs_visits,
      function(threshold, visits) {
        softmax(
          option_values = c(threshold, visits),
          option_chosen = 1,
          temperature = 1
        )
      }
    ),
    # Weigh the BFS predictions accordingly: an abandoned search contributes
    # a probability of 1/2 of being correct
    p_give_up = 1 - p_complete_bfs,
    model_p_correct = (p_complete_bfs * p_bfs_correct) + (p_give_up / 2)
  ) %>%
  # Attach what each subject actually did on the matching trial
  left_join(
    bind_rows(nav_study1, nav_study2, nav_study3) %>%
      select(
        study, sub_id, measurement_id,
        shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id,
        sub_choice,
        sub_correct = correct,
        sub_rt = rt
      ),
    by = join_by(
      study, sub_id, measurement_id,
      shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id
    )
  )
```

```{r ppc-bfs-forward}
# Posterior predictive check for BFS-forward: predicted accuracy on every
# simulated trial for every subject/measurement, alongside observed behavior
ppc_bfs_forward <- expand_grid(
  # Every subject/measurement crossed with every BFS-simulated trial
  distinct(params, study, sub_id, measurement_id),
  bfs_forward_sims
) %>%
  # Attach each subject's fitted search threshold
  left_join(
    params %>%
      filter(model == "bfs_forward") %>%
      pivot_wider(names_from = param_name, values_from = param_value) %>%
      select(study, sub_id, measurement_id, search_threshold),
    by = join_by(study, sub_id, measurement_id)
  ) %>%
  mutate(
    # Probability of *completing* the BFS search all the way through: a
    # row-by-row softmax over the subject's threshold and the number of
    # visits that the search requires
    p_complete_bfs = map2_dbl(
      search_threshold, bfs_visits,
      function(threshold, visits) {
        softmax(
          option_values = c(threshold, visits),
          option_chosen = 1,
          temperature = 1
        )
      }
    ),
    # Weigh the BFS predictions accordingly: an abandoned search contributes
    # a probability of 1/2 of being correct
    p_give_up = 1 - p_complete_bfs,
    model_p_correct = (p_complete_bfs * p_bfs_correct) + (p_give_up / 2)
  ) %>%
  # Attach what each subject actually did on the matching trial
  left_join(
    bind_rows(nav_study1, nav_study2, nav_study3) %>%
      select(
        study, sub_id, measurement_id,
        shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id,
        sub_choice,
        sub_correct = correct,
        sub_rt = rt
      ),
    by = join_by(
      study, sub_id, measurement_id,
      shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id
    )
  )
```

```{r ppc-ideal-obs}
# Posterior predictive check for the ideal observer: predicted accuracy on
# every trial for every subject/measurement, alongside observed behavior
ppc_ideal_obs <- expand_grid(
  # Every subject/measurement crossed with every navigation trial
  distinct(params, study, sub_id, measurement_id),
  nav_trials
) %>%
  # Attach each subject's fitted softmax temperature
  left_join(
    params %>%
      filter(model == "ideal_obs") %>%
      pivot_wider(names_from = param_name, values_from = param_value) %>%
      select(study, sub_id, measurement_id, softmax_temperature),
    by = join_by(study, sub_id, measurement_id)
  ) %>%
  # Model prediction: row-by-row softmax over the two options' distances,
  # evaluated at whichever option (1 or 2) is the correct choice
  mutate(
    model_p_correct = pmap_dbl(
      list(
        opt1_distance,
        opt2_distance,
        if_else(correct_choice == opt1_id, 1, 2),
        softmax_temperature
      ),
      function(dist1, dist2, correct_idx, inv_temp) {
        softmax(
          option_values = c(dist1, dist2),
          option_chosen = correct_idx,
          temperature = inv_temp,
          use_inverse_temperature = TRUE
        )
      }
    )
  ) %>%
  # Attach what each subject actually did on the matching trial
  left_join(
    bind_rows(nav_study1, nav_study2, nav_study3) %>%
      select(
        study, sub_id, measurement_id,
        shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id,
        sub_choice,
        sub_correct = correct,
        sub_rt = rt
      ),
    by = join_by(
      study, sub_id, measurement_id,
      shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id
    )
  )
```

```{r ppc-sr}
# Predicted SR for every subject/measurement, built analytically from the
# transition matrix using that subject's fitted gamma
ppc_sr_representation <- params %>%
  filter(model == "sr_analytic") %>%
  pivot_wider(names_from = param_name, values_from = param_value) %>%
  select(study, sub_id, measurement_id, sr_gamma) %>%
  mutate(
    predicted_sr = map(
      sr_gamma,
      function(gamma) {
        build_successor_analytically(
          transmat,
          successor_horizon = gamma,
          normalize = TRUE
        )
      }
    )
  ) %>%
  select(!sr_gamma) %>%
  # Expand the nested SR tables (columns used downstream: from, to, sr_value)
  unnest(predicted_sr)

# Posterior predictive check for the SR: predicted accuracy on every trial
# for every subject/measurement, alongside observed behavior
ppc_sr_navigation <- expand_grid(
  # Every subject/measurement crossed with every navigation trial
  distinct(params, study, sub_id, measurement_id),
  nav_trials
) %>%
  # Attach each subject's fitted softmax temperature
  left_join(
    params %>%
      filter(model == "sr_analytic") %>%
      pivot_wider(names_from = param_name, values_from = param_value) %>%
      select(study, sub_id, measurement_id, softmax_temperature),
    by = join_by(study, sub_id, measurement_id)
  ) %>%
  # SR value from option 1 to the endpoint
  left_join(
    ppc_sr_representation %>%
      select(
        study, sub_id, measurement_id,
        endpoint_id = to,
        opt1_id = from,
        opt1_sr = sr_value
      ),
    by = join_by(study, sub_id, measurement_id, endpoint_id, opt1_id)
  ) %>%
  # SR value from option 2 to the endpoint
  left_join(
    ppc_sr_representation %>%
      select(
        study, sub_id, measurement_id,
        endpoint_id = to,
        opt2_id = from,
        opt2_sr = sr_value
      ),
    by = join_by(study, sub_id, measurement_id, endpoint_id, opt2_id)
  ) %>%
  # Model prediction: row-by-row softmax over the two options' SR values,
  # evaluated at whichever option (1 or 2) is the correct choice
  mutate(
    model_p_correct = pmap_dbl(
      list(
        opt1_sr,
        opt2_sr,
        if_else(correct_choice == opt1_id, 1, 2),
        softmax_temperature
      ),
      function(sr1, sr2, correct_idx, inv_temp) {
        softmax(
          option_values = c(sr1, sr2),
          option_chosen = correct_idx,
          temperature = inv_temp,
          use_inverse_temperature = TRUE
        )
      }
    )
  ) %>%
  # Attach what each subject actually did on the matching trial
  left_join(
    bind_rows(nav_study1, nav_study2, nav_study3) %>%
      select(
        study, sub_id, measurement_id,
        shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id,
        sub_choice,
        sub_correct = correct,
        sub_rt = rt
      ),
    by = join_by(
      study, sub_id, measurement_id,
      shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id
    )
  )
```

## Plot PPC

```{r create-ppc-for-plotting}
# Long-format table of human and model-predicted accuracy per subject,
# measurement, and shortest-path distance. Only trials with a single correct
# option are included.
ppc_for_plotting <- distinct(params, study, sub_id, measurement_id) %>%
  # Human accuracy + BFS-backward accuracy
  # (this join is what introduces the shortest_path column)
  left_join(
    ppc_bfs_backward %>%
      filter(two_correct_options == FALSE) %>%
      group_by(study, sub_id, measurement_id, shortest_path) %>%
      summarise(
        human = mean(sub_correct),
        bfs_backward = mean(model_p_correct),
        .groups = "drop"
      ),
    by = join_by(study, sub_id, measurement_id)
  ) %>%
  # BFS-forward accuracy
  left_join(
    ppc_bfs_forward %>%
      filter(two_correct_options == FALSE) %>%
      group_by(study, sub_id, measurement_id, shortest_path) %>%
      summarise(bfs_forward = mean(model_p_correct), .groups = "drop"),
    by = join_by(study, sub_id, measurement_id, shortest_path)
  ) %>%
  # Ideal observer accuracy
  left_join(
    ppc_ideal_obs %>%
      filter(two_correct_options == FALSE) %>%
      group_by(study, sub_id, measurement_id, shortest_path) %>%
      summarise(ideal_obs = mean(model_p_correct), .groups = "drop"),
    by = join_by(study, sub_id, measurement_id, shortest_path)
  ) %>%
  # SR accuracy
  left_join(
    ppc_sr_navigation %>%
      filter(two_correct_options == FALSE) %>%
      group_by(study, sub_id, measurement_id, shortest_path) %>%
      summarise(sr = mean(model_p_correct), .groups = "drop"),
    by = join_by(study, sub_id, measurement_id, shortest_path)
  ) %>%
  # One row per agent, for plotting
  pivot_longer(
    cols = c(human, bfs_backward, bfs_forward, ideal_obs, sr),
    names_to = "agent",
    values_to = "accuracy"
  ) %>%
  mutate(
    # Display names for the facets
    agent = case_match(
      agent,
      "human" ~ "Human",
      "bfs_backward" ~ "BFS-backward",
      "bfs_forward" ~ "BFS-forward",
      "ideal_obs" ~ "Ideal observer",
      "sr" ~ "Successor Rep."
    ),
    # Put humans first and the SR second
    agent = fct_relevel(agent, "Human", "Successor Rep.")
  )
```

Let's look at the main set of trials and compare human performance against the models on Day 1 (i.e., before rest).

```{r ppc-day1}
# Accuracy by shortest-path distance for humans and each model (D1 only)
plot_ppc_day1 <- ppc_for_plotting %>%
  filter(measurement_id == "D1") %>%
  ggplot(aes(x = shortest_path, y = accuracy)) +
  theme_custom() +
  facet_wrap(vars(agent), nrow = 1) +
  # Dashed reference line at 50% accuracy
  geom_hline(yintercept = 0.5, linetype = "dashed", color = "blue") +
  # One faint line per subject
  geom_line(aes(group = interaction(study, sub_id)), alpha = 0.1) +
  # Red bars mark the mean
  stat_summary(geom = "crossbar", fun = mean, color = "red") +
  scale_x_discrete(name = "Shortest path distance") +
  scale_y_continuous(name = "Accuracy", labels = scales::percent) +
  labs(title = "Posterior predictive check: Before rest")

plot_ppc_day1

# Only write the figure to disk when knitting
if (knitting) {
  ggsave(
    here("outputs", workflow_name, "ppc_day1.pdf"),
    plot = plot_ppc_day1,
    width = 6,
    height = 3,
    units = "in",
    dpi = 300
  )
}
```

We'll do the same now for Day 2 (i.e., after overnight rest).

```{r ppc-day2}
# Accuracy by shortest-path distance for humans and each model (D2 only)
plot_ppc_day2 <- ppc_for_plotting %>%
  filter(measurement_id == "D2") %>%
  ggplot(aes(x = shortest_path, y = accuracy)) +
  theme_custom() +
  facet_wrap(vars(agent), nrow = 1) +
  # Dashed reference line at 50% accuracy
  geom_hline(yintercept = 0.5, linetype = "dashed", color = "blue") +
  # One faint line per subject
  geom_line(aes(group = interaction(study, sub_id)), alpha = 0.1) +
  # Red bars mark the mean
  stat_summary(geom = "crossbar", fun = mean, color = "red") +
  scale_x_discrete(name = "Shortest path distance") +
  scale_y_continuous(name = "Accuracy", labels = scales::percent) +
  labs(title = "Posterior predictive check: After overnight rest")

plot_ppc_day2

# Only write the figure to disk when knitting
if (knitting) {
  ggsave(
    here("outputs", workflow_name, "ppc_day2.pdf"),
    plot = plot_ppc_day2,
    width = 6,
    height = 3,
    units = "in",
    dpi = 300
  )
}
```

And finally, for Day 1b (i.e., after awake rest on Day 1), a measurement that was only made in Study 3.

```{r ppc-day1b}
# Accuracy by shortest-path distance for humans and each model (D1b only)
plot_ppc_day1b <- ppc_for_plotting %>%
  filter(measurement_id == "D1b") %>%
  ggplot(aes(x = shortest_path, y = accuracy)) +
  theme_custom() +
  facet_wrap(vars(agent), nrow = 1) +
  # Dashed reference line at 50% accuracy
  geom_hline(yintercept = 0.5, linetype = "dashed", color = "blue") +
  # One faint line per subject
  geom_line(aes(group = interaction(study, sub_id)), alpha = 0.1) +
  # Red bars mark the mean
  stat_summary(geom = "crossbar", fun = mean, color = "red") +
  scale_x_discrete(name = "Shortest path distance") +
  scale_y_continuous(name = "Accuracy", labels = scales::percent) +
  labs(title = "Posterior predictive check: After awake rest")

plot_ppc_day1b

# Only write the figure to disk when knitting
if (knitting) {
  ggsave(
    here("outputs", workflow_name, "ppc_day1b.pdf"),
    plot = plot_ppc_day1b,
    width = 6,
    height = 3,
    units = "in",
    dpi = 300
  )
}
```


# Held-out trials

The primary analyses, including the parameter-fitting, were performed on a set of trials where there was always one unambiguously correct answer. For this reason, there was also a subset of trials that got "held out" because the two options had the same shortest path distance from the target.

Several of the computational models (i.e., BFS-backward and ideal observer) therefore predict 50/50 choices on these trials. BFS-forward does not, as it allows for stochasticity in how agents perform searches from each of the two options. Most notably, the SR often predicts that an agent will prefer one option over another. This reflects the fact that, although (e.g.) Sources A and B have the same shortest path distance to the Target, Source A might have a greater number of short paths to the Target.

Therefore, if subjects' behaviors are consistent with non-random responding, we'd ideally like to see that the SR is able to make more accurate out-of-sample predictions on these trials.

```{r calc-heldout-likelihoods}
# Per-trial negative log-likelihoods of subjects' choices on the held-out
# trials (those flagged `two_correct_options`), under each of four models:
# SR, BFS-backward, BFS-forward, and ideal observer.
#
# Output columns: trial identifiers, `sr_prefers` (ID of the option with the
# larger SR value; NA when the SR is indifferent), `sub_choice`, and one
# `neg_ll_*` column per model.
heldout_likelihoods <- ppc_sr_navigation %>%
  filter(two_correct_options == TRUE) %>%
  select(
    study, sub_id, measurement_id,
    shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id,
    opt1_sr, opt2_sr, sub_choice, softmax_temperature
  ) %>%
  mutate(
    sr_prefers = case_when(
      opt1_sr == opt2_sr ~ NA_real_,
      opt1_sr > opt2_sr ~ opt1_id,
      TRUE ~ opt2_id
    )
  ) %>%
  # Calculate the likelihood of the option the subject actually chose, given
  # the SR values of the two options. (softmax is applied one trial at a time,
  # hence rowwise.)
  rowwise() %>%
  mutate(
    p_sub_choice = softmax(
      option_values = c(opt1_sr, opt2_sr),
      # Index the subject's choice, NOT the SR-preferred option: indexing the
      # SR-preferred option would yield a probability that is always >= 0.5
      # and does not depend on what the subject did.
      # NOTE(review): assumes sub_choice is never NA on these trials -- confirm
      option_chosen = if_else(sub_choice == opt1_id, 1, 2),
      temperature = softmax_temperature,
      use_inverse_temperature = TRUE
    ),
    # Fix a few edge cases
    p_sub_choice = case_when(
      # SR is indifferent between the options
      is.na(sr_prefers) ~ 0.5,
      # softmax returned NaN: treat the SR preference as deterministic
      is.nan(p_sub_choice) & (sub_choice == sr_prefers) ~ 1,
      # To avoid log(0), use machine epsilon
      is.nan(p_sub_choice) & (sub_choice != sr_prefers) ~ 2.22e-16,
      TRUE ~ p_sub_choice
    )
  ) %>%
  ungroup() %>%
  mutate(neg_ll_sr = neg_loglik_logistic(p_sub_choice)) %>%
  # Tidy up the SR bit
  select(
    study, sub_id, measurement_id,
    shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id,
    sr_prefers, sub_choice,
    neg_ll_sr
  ) %>%
  # Add likelihoods for BFS-backward, which always predicts 50/50 responding
  mutate(neg_ll_bfs_backward = neg_loglik_logistic(0.5)) %>%
  # Add likelihoods for BFS-forward
  left_join(
    ppc_bfs_forward %>%
      select(
        study, sub_id, measurement_id,
        shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id,
        p_bfs_chooses_opt1, p_complete_bfs, p_give_up
      ),
    by = join_by(
      study, sub_id, measurement_id,
      shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id
    )
  ) %>%
  mutate(
    # Probability that BFS-forward picks the option the subject chose: either
    # the search completes and selects that option, or the agent gives up and
    # guesses (1/2)
    p_sub_choice = case_when(
      # Trials where the SR is indifferent are scored at 0.5 here too, matching
      # how those trials are scored for the SR
      is.na(sr_prefers) ~ 0.5,
      sub_choice == opt1_id ~ (
        (p_complete_bfs * p_bfs_chooses_opt1) + (p_give_up * 1/2)
      ),
      sub_choice == opt2_id ~ (
        (p_complete_bfs * (1-p_bfs_chooses_opt1)) + (p_give_up * 1/2)
      )
    ),
    neg_ll_bfs_forward = neg_loglik_logistic(p_sub_choice)
  ) %>%
  # Add likelihoods for ideal observer, which always predicts 50/50 responding
  mutate(neg_ll_ideal_obs = neg_loglik_logistic(0.5)) %>%
  # Tidy up
  select(
    study, sub_id, measurement_id,
    shortest_path, startpoint_id, endpoint_id, opt1_id, opt2_id,
    sr_prefers, sub_choice,
    neg_ll_sr, neg_ll_bfs_backward, neg_ll_bfs_forward, neg_ll_ideal_obs
  )
```

Below, we see that the SR, compared to the other models, is doing a better job of explaining subjects' choices on held-out trials, though we again note that the likelihoods for the BFS-backward and ideal observer models are for completely random responding. The different studies have different baseline log-likelihoods because they contain a different number of held-out trials (note that log-likelihoods are summed separately for each subject and measurement).

```{r plot-heldout-likelihoods}
# Held-out log-likelihoods for each model, split by study (panels) and
# measurement (x-axis). Points are individual subjects; crossbars mark the
# median across subjects.
plot_heldout_likelihoods <- heldout_likelihoods %>%
  # Total neg-loglik for each study x subject x measurement
  group_by(study, sub_id, measurement_id) %>%
  summarise(across(starts_with("neg_ll_"), sum), .groups = "drop") %>%
  # Long format: one row per model, with the "neg_ll_" prefix stripped from the
  # model name
  pivot_longer(
    cols = starts_with("neg_ll_"),
    names_to = "model",
    names_prefix = "neg_ll_",
    values_to = "neg_ll"
  ) %>%
  # Flip the sign so that greater = better
  mutate(loglik = -neg_ll) %>%
  ggplot(mapping = aes(x = measurement_id, y = loglik, color = model)) +
  theme_custom() +
  theme(legend.position = "bottom") +
  stat_summary(
    geom = "crossbar",
    fun = median,
    position = position_dodge(width = 0.75)
  ) +
  geom_point(
    position = position_jitterdodge(
      jitter.width = 0.1, jitter.height = 0, dodge.width = 0.75, seed = 1
    ),
    alpha = 0.2,
    show.legend = FALSE
  ) +
  facet_wrap(~study, scales = "free_x") +
  scale_color_manual(
    name = NULL,
    values = c(
      "bfs_backward" = "#a6dba0",
      "bfs_forward" = "#5aae61",
      "ideal_obs" = "#1b7837",
      "sr" = "#af8dc3"
    ),
    labels = c(
      "bfs_backward" = "BFS-backward",
      "bfs_forward" = "BFS-forward",
      "ideal_obs" = "Ideal observer",
      "sr" = "Successor Rep."
    )
  ) +
  scale_x_discrete(
    name = NULL,
    labels = c(
      "D1" = "Before\nrest",
      "D2" = "After\novernight\nrest",
      "D1b" = "After\nawake\nrest"
    )
  ) +
  scale_y_continuous(name = "log-likelihood\n(greater = better)") +
  ggtitle("Held-out trials: log-likelihoods")

plot_heldout_likelihoods

# Only write the figure to disk when the document is being knit to HTML
if (knitting) {
  ggsave(
    here("outputs", workflow_name, "heldout_likelihoods.pdf"),
    plot_heldout_likelihoods,
    units = "in", width = 6, height = 4, dpi = 300
  )
}
```

Although the likelihoods give us a nice quantitative metric, we may also be interested in knowing how well the SR predicts subjects' choices just in terms of accuracy. In the below analysis, we'll use mixed-effects logistic regression to test whether subjects choose the SR-preferred option at above-chance rates. Note that we're removing all of the trials where the SR is indifferent to the two options, as those trials would lead us to overestimate the model's predicted accuracy (i.e., because the subject is always right on those trials). Note also that we're iteratively re-parameterizing the model with a different reference category each time, so that we can test whether model accuracy is significantly above chance at each distance. Finally, note that for the purpose of statistical testing, we're looking at the two main measurements: before rest (day 1), and after overnight rest (day 2).

```{r test-heldout-accuracy}
# Fit the mixed-effects logistic regression of "subject chose the SR-preferred
# option" on shortest path distance, with random intercepts for subject and
# study.
#
# reference_distance: level of `shortest_path` to move to the front so that it
#   becomes the reference category (e.g., "3"). When NULL (the default), the
#   existing level order of `shortest_path` is left untouched.
#
# Returns the fitted glmmTMB model. Only the "D1" and "D2" measurements are
# used, and trials where the SR is indifferent (`sr_prefers` is NA) are
# dropped.
fit_heldout_accuracy <- function(reference_distance = NULL) {
  model_data <- heldout_likelihoods %>%
    filter(measurement_id %in% c("D1", "D2")) %>%
    filter(!is.na(sr_prefers)) %>%
    mutate(
      p_sub_chooses_sr_preference = sub_choice == sr_prefers,
      # Prefix subject IDs with the study so that the random-effect grouping
      # is by study-and-subject rather than by bare subject ID
      sub_id = str_c(study, ", s", sub_id)
    )

  if (!is.null(reference_distance)) {
    model_data <- model_data %>%
      mutate(shortest_path = fct_relevel(shortest_path, reference_distance))
  }

  glmmTMB(
    p_sub_chooses_sr_preference ~ shortest_path +
      (1 | sub_id) + (1 | study),
    family = binomial,
    data = model_data
  )
}

# Same model, re-parameterized with a different reference category each time,
# so that each fit's intercept corresponds to a different distance
stats_heldout_accuracy_dist2 <- fit_heldout_accuracy()
stats_heldout_accuracy_dist3 <- fit_heldout_accuracy("3")
stats_heldout_accuracy_dist4 <- fit_heldout_accuracy("4")

# Stack the tidied coefficient tables, labelled by reference category
map_dfr(
  .x = list(
    "dist2" = stats_heldout_accuracy_dist2,
    "dist3" = stats_heldout_accuracy_dist3,
    "dist4" = stats_heldout_accuracy_dist4
  ),
  .f = ~tidy(.x) %>% select(-component),
  .id = "ref_cat"
) %>%
  kable_custom(
    "Held-out trials: SR model accuracy",
    grouping_var = ref_cat
  )
```

In the plot below, we can see that there's quite a bit of variability across subjects (black lines and datapoints), but also that at the group level, the SR achieves above-chance accuracy for the held-out distance-3 and distance-4 trials (red).

```{r plot-heldout-accuracy}
# Model predictions at each distance. predict_glmmTMB() predicts on the
# response scale with random effects excluded, which is why sub_id and study
# are left as NA.
predict_heldout_accuracy <- predict_glmmTMB(
  make_predictions_for = tibble(
    shortest_path = factor(2:4), sub_id = NA, study = NA
  ),
  model_object = stats_heldout_accuracy_dist2
)

# Per-subject proportion of SR-consistent choices at each distance (black),
# overlaid with the model's predictions +/- 1 SE (red)
plot_heldout_accuracy <- heldout_likelihoods %>%
  filter(measurement_id %in% c("D1", "D2"), !is.na(sr_prefers)) %>%
  group_by(study, sub_id, shortest_path) %>%
  summarise(
    p_sub_chooses_sr_preference = mean(sub_choice == sr_prefers),
    .groups = "drop"
  ) %>%
  ggplot(mapping = aes(x = shortest_path, y = p_sub_chooses_sr_preference)) +
  theme_custom() +
  # Dashed reference line at 50%
  geom_hline(yintercept = 0.5, linetype = "dashed") +
  geom_line(aes(group = interaction(study, sub_id)), alpha = 0.1) +
  geom_point(alpha = 0.1) +
  geom_pointrange(
    data = predict_heldout_accuracy,
    mapping = aes(
      x = shortest_path,
      y = fit,
      ymin = fit - se.fit,
      ymax = fit + se.fit
    ),
    inherit.aes = FALSE,
    color = "red", linewidth = 1
  ) +
  scale_y_continuous(
    name = "p(Human chooses SR-preferred Source)",
    breaks = seq(0, 1, .25),
    labels = scales::percent,
    expand = expansion(mult = c(0.1, 0.15))
  ) +
  scale_x_discrete(name = "Shortest path distance") +
  ggtitle("Held-out trials: SR model accuracy")

plot_heldout_accuracy

# Only write the figure to disk when the document is being knit to HTML
if (knitting) {
  ggsave(
    here("outputs", workflow_name, "heldout_accuracy.pdf"),
    plot_heldout_accuracy,
    units = "in", width = 4, height = 5, dpi = 300
  )
}
```


# McFadden's pseudo R-squared

Finally, it can be useful to get a sense for each model's "goodness-of-fit" by comparing its log-likelihood to the log-likelihood of a null model (i.e., McFadden's R-squared, defined as one minus the ratio of the model's log-likelihood to the null model's log-likelihood). Here, the null model is an agent that chooses completely at random on every trial.

```{r calc-mcfadden-r2}
# McFadden's pseudo R-squared for every fitted model (excluding
# "sr_delta_rule"): 1 - (model log-likelihood / null log-likelihood)
mcfadden_r2 <- aicc %>%
  filter(model != "sr_delta_rule") %>%
  mutate(model_loglik = -neg_loglik) %>%
  select(study, sub_id, measurement_id, model, model_loglik) %>%
  mutate(
    # Null model is chance-level choice on every trial, so its log-likelihood
    # is simply (number of trials) * log(0.5).
    # NOTE(review): the trial count is every row of nav_trials with
    # two_correct_options == FALSE, applied to all rows alike -- presumably
    # nav_trials holds a single subject's worth of trials; confirm
    null_loglik = nrow(
      filter(nav_trials, two_correct_options == FALSE)
    ) * log(0.5),
    mcfadden_r2 = 1 - (model_loglik / null_loglik)
  )

# Table of means: one row per study x measurement, one column per model
mcfadden_r2 %>%
  group_by(study, measurement_id, model) %>%
  summarise(Mean = mean(mcfadden_r2), .groups = "drop") %>%
  pivot_wider(names_from = model, values_from = Mean) %>%
  kable_custom("Mean McFadden's R2", grouping_var = study)

# Same table, using medians
mcfadden_r2 %>%
  group_by(study, measurement_id, model) %>%
  summarise(Median = median(mcfadden_r2), .groups = "drop") %>%
  pivot_wider(names_from = model, values_from = Median) %>%
  kable_custom("Median McFadden's R2", grouping_var = study)
```

We see in the plot that, generally, all models are doing better than the null model, and also that the SR consistently seems to have the best likelihood ratio.

```{r plot-mcfadden-r2}
# McFadden's R-squared for each model, split by study (panels) and measurement
# (x-axis). Points are individual rows of mcfadden_r2; crossbars mark the
# median.
plot_mcfadden_r2 <- ggplot(
  data = mcfadden_r2,
  mapping = aes(x = measurement_id, y = mcfadden_r2, color = model)
) +
  theme_custom() +
  theme(legend.position = "bottom") +
  geom_point(
    position = position_jitterdodge(
      jitter.width = 0.1, jitter.height = 0, dodge.width = 0.75, seed = 1
    ),
    alpha = 0.1,
    show.legend = FALSE
  ) +
  stat_summary(
    geom = "crossbar",
    fun = median,
    position = position_dodge(width = 0.75)
  ) +
  facet_wrap(~study, scales = "free_x") +
  scale_color_manual(
    name = NULL,
    values = c(
      "bfs_backward" = "#a6dba0",
      "bfs_forward" = "#5aae61",
      "ideal_obs" = "#1b7837",
      "sr_analytic" = "#af8dc3"
    ),
    labels = c(
      "bfs_backward" = "BFS-backward",
      "bfs_forward" = "BFS-forward",
      "ideal_obs" = "Ideal observer",
      "sr_analytic" = "Successor Representation"
    )
  ) +
  scale_x_discrete(
    name = NULL,
    labels = c(
      "D1" = "Before\nrest",
      "D2" = "After\novernight\nrest",
      "D1b" = "After\nawake\nrest"
    )
  ) +
  scale_y_continuous(name = bquote(~"McFadden's"~R^2)) +
  ggtitle("Model goodness-of-fit")

plot_mcfadden_r2

# Only write the figure to disk when the document is being knit to HTML
if (knitting) {
  ggsave(
    here("outputs", workflow_name, "mcfadden_r2.pdf"),
    plot_mcfadden_r2,
    units = "in", width = 6, height = 4, dpi = 300
  )
}
```


# Figures for supplement

```{r plot-model-comparison-for-supp}
# Three-panel supplementary figure: the model-comparison plots side by side in
# a single row, with a shared legend and A./B./C. panel tags
plot_model_comparison_for_supp <- (
  wrap_plots(
    plot_akaike_group,
    plot_best_fitting_model_prop,
    plot_pxp,
    nrow = 1,
    guides = "collect"
  ) +
    plot_annotation(
      title = "Model comparison",
      tag_levels = "A",
      tag_suffix = ".",
      theme = theme(plot.title = element_text(hjust = 0.5))
    )
) &
  # `&` applies the legend position to every panel
  theme(legend.position = "bottom")

plot_model_comparison_for_supp

# Only write the figure to disk when the document is being knit to HTML
if (knitting) {
  ggsave(
    here("figures", "supp_model_comparison.pdf"),
    plot_model_comparison_for_supp,
    units = "in", width = 16, height = 4, dpi = 300
  )
}
```

```{r plot-ppc-for-supp}
# Supplementary posterior predictive check: the Day 2 and Day 1b plots stacked
# in one column, retitled, with A./B. panel tags
plot_ppc_for_supp <- (
  wrap_plots(
    plot_ppc_day2 + ggtitle("After overnight rest (Studies 2-3)"),
    plot_ppc_day1b + ggtitle("After awake rest (Study 3)"),
    ncol = 1
  ) +
    plot_annotation(
      title = "Posterior predictive check",
      tag_levels = "A",
      tag_suffix = ".",
      theme = theme(plot.title = element_text(hjust = 0.5))
    )
) &
  # `&` gives every panel the same y-axis limits and breaks
  scale_y_continuous(
    name = "Accuracy",
    labels = scales::percent,
    breaks = seq(.25, 1, .25),
    limits = c(.25, 1)
  )

plot_ppc_for_supp

# Only write the figure to disk when the document is being knit to HTML
if (knitting) {
  ggsave(
    here("figures", "supp_ppc.pdf"),
    plot_ppc_for_supp,
    units = "in", width = 8, height = 6, dpi = 300
  )
}
```

There are some plots that we want to use as-is, so we'll save a redundant copy in the `figures` folder.

```{r save-redundant-copies}
# Only write the figures to disk when the document is being knit to HTML
if (knitting) {
  # One entry per figure: file name inside `figures`, the plot object, and the
  # page size in inches
  redundant_copy_specs <- list(
    list(
      file = "supp_param_estimates.pdf",
      plot = plot_params_all, width = 12, height = 6
    ),
    list(
      file = "supp_akaike_weights_per_subject.pdf",
      plot = plot_akaike_individual, width = 8, height = 10
    ),
    list(
      file = "supp_heldout_likelihoods.pdf",
      plot = plot_heldout_likelihoods, width = 6, height = 4
    ),
    list(
      file = "supp_heldout_accuracy.pdf",
      plot = plot_heldout_accuracy, width = 4, height = 5
    ),
    list(
      file = "supp_mcfadden_r2.pdf",
      plot = plot_mcfadden_r2, width = 6, height = 4
    )
  )

  # Save each figure in turn, in the order listed above
  for (copy_spec in redundant_copy_specs) {
    ggsave(
      filename = here("figures", copy_spec[["file"]]),
      plot = copy_spec[["plot"]],
      width = copy_spec[["width"]], height = copy_spec[["height"]],
      units = "in", dpi = 300
    )
  }
}
```

